# Source code for angr.analyses.decompiler.ail_simplifier

# pylint:disable=too-many-boolean-expressions,consider-using-enumerate
from __future__ import annotations
from typing import Any, TYPE_CHECKING
from collections.abc import Iterable
from collections import defaultdict
import logging

import networkx

from ailment import AILBlockWalker
from ailment.block import Block
from ailment.statement import Statement, Assignment, Store, Call, ConditionalJump, DirtyStatement
from ailment.expression import (
    Register,
    Convert,
    Load,
    StackBaseOffset,
    Expression,
    DirtyExpression,
    VEXCCallExpression,
    Tmp,
    Const,
    BinaryOp,
    VirtualVariable,
)

from angr.analyses.s_propagator import SPropagatorAnalysis
from angr.analyses.s_reaching_definitions import SRDAModel
from angr.utils.ail import is_phi_assignment, HasExprWalker
from angr.code_location import CodeLocation, ExternalCodeLocation
from angr.sim_variable import SimStackVariable, SimMemoryVariable, SimVariable
from angr.knowledge_plugins.propagations.states import Equivalence
from angr.knowledge_plugins.key_definitions import atoms
from angr.knowledge_plugins.key_definitions.definition import Definition
from angr.knowledge_plugins.key_definitions.constants import OP_BEFORE
from angr.errors import AngrRuntimeError
from angr.analyses import Analysis, AnalysesHub
from .ailgraph_walker import AILGraphWalker
from .expression_narrower import ExprNarrowingInfo, NarrowingInfoExtractor, ExpressionNarrower
from .block_simplifier import BlockSimplifier
from .ccall_rewriters import CCALL_REWRITERS
from .counters.expression_counters import SingleExpressionCounter

if TYPE_CHECKING:
    from ailment.manager import Manager


_l = logging.getLogger(__name__)


class HasCallNotification(Exception):
    """
    Notifies the existence of a call statement.

    This is a plain ``Exception`` subclass with no extra state; it is used as a signal, not as an error report.
    """
class HasVVarNotification(Exception):
    """
    Notifies the existence of a VirtualVariable.

    This is a plain ``Exception`` subclass with no extra state; it is used as a signal, not as an error report.
    """
class AILBlockTempCollector(AILBlockWalker):
    """
    Collects any temporaries used in a block.

    After walking a statement or an expression, every ``Tmp`` expression that was visited is available in
    :attr:`temps`.
    """

    def __init__(self, **kwargs):
        """
        :param kwargs: Forwarded unchanged to ``AILBlockWalker.__init__``.
        """
        super().__init__(**kwargs)
        # all Tmp expressions encountered so far
        self.temps: set[Tmp] = set()
        # route every Tmp expression to our collector
        self.expr_handlers[Tmp] = self._handle_Tmp

    # pylint:disable=unused-argument
    def _handle_Tmp(self, expr_idx: int, expr: Expression, stmt_idx: int, stmt: Statement, block) -> None:
        """
        Record ``expr`` if it is a ``Tmp``. Always returns None; the remaining arguments are unused.
        """
        if isinstance(expr, Tmp):
            self.temps.add(expr)
class AILSimplifier(Analysis):
    """
    Perform function-level simplifications.
    """

    def __init__(
        self,
        func,
        func_graph=None,
        remove_dead_memdefs=False,
        stack_arg_offsets: set[tuple[int, int]] | None = None,
        unify_variables=False,
        ail_manager: Manager | None = None,
        gp: int | None = None,
        narrow_expressions=False,
        only_consts=False,
        fold_callexprs_into_conditions=False,
        use_callee_saved_regs_at_return=True,
        rewrite_ccalls=True,
        removed_vvar_ids: set[int] | None = None,
        arg_vvars: dict[int, tuple[VirtualVariable, SimVariable]] | None = None,
        avoid_vvar_ids: set[int] | None = None,
    ):
        """
        Store the configuration and immediately run the simplification passes (see ``_simplify``).

        :param func:                The function to simplify.
        :param func_graph:          The AIL graph to work on. Falls back to ``func.graph`` when None.
        :param remove_dead_memdefs: Stored as ``_remove_dead_memdefs``; not read by the constructor itself.
        :param stack_arg_offsets:   A set of tuples whose first element is an instruction address;
                                    ``_fold_exprs`` only replaces loads in blocks that contain such an instruction.
        :param unify_variables:     Enable the dead-assignment-removal and local-variable-unification passes.
        :param ail_manager:         AIL manager; its ``atom_ctr`` supplies indices for newly created expressions.
                                    When None, new expressions are created with index None.
        :param gp:                  Forwarded as ``gp`` to ``BlockSimplifier._replace_and_build``.
        :param narrow_expressions:  Enable the expression-narrowing pass.
        :param only_consts:         Forwarded to the propagator; when set, ``_simplify`` stops after expression
                                    folding.
        :param fold_callexprs_into_conditions: Stored as ``_fold_callexprs_into_conditions``; described in the
                                    docstring of ``_fold_call_exprs``.
        :param use_callee_saved_regs_at_return: Stored as ``_use_callee_saved_regs_at_return``; not read by the
                                    constructor itself.
        :param rewrite_ccalls:      Enable the ccall-rewriting pass.
        :param removed_vvar_ids:    Set of virtual variable IDs; a new empty set is created when None.
        :param arg_vvars:           Mapping from argument index to ``(VirtualVariable, SimVariable)``; the virtual
                                    variables are handed to the data-flow analyses as function arguments.
        :param avoid_vvar_ids:      IDs of virtual variables whose replacements ``_fold_exprs`` must drop.
        """
        self.func = func
        self.func_graph = func_graph if func_graph is not None else func.graph
        # lazily computed analysis results; reset through _clear_cache()
        self._reaching_definitions: SRDAModel | None = None
        self._propagator = None

        self._remove_dead_memdefs = remove_dead_memdefs
        self._stack_arg_offsets = stack_arg_offsets
        self._unify_vars = unify_variables
        self._ail_manager: Manager | None = ail_manager
        self._gp = gp
        self._narrow_expressions = narrow_expressions
        self._only_consts = only_consts
        self._fold_callexprs_into_conditions = fold_callexprs_into_conditions
        self._use_callee_saved_regs_at_return = use_callee_saved_regs_at_return
        self._should_rewrite_ccalls = rewrite_ccalls
        self._removed_vvar_ids = removed_vvar_ids if removed_vvar_ids is not None else set()
        self._arg_vvars = arg_vvars
        self._avoid_vvar_ids = avoid_vvar_ids

        # locations collected by individual passes and honored by later ones
        self._calls_to_remove: set[CodeLocation] = set()
        self._assignments_to_remove: set[CodeLocation] = set()
        self.blocks = {}  # Mapping nodes to simplified blocks

        self.simplified: bool = False

        self._simplify()
def _simplify(self): if self._narrow_expressions: _l.debug("Removing dead assignments before narrowing expressions") r = self._remove_dead_assignments() if r: _l.debug("... dead assignments removed") self.simplified = True self._rebuild_func_graph() self._clear_cache() _l.debug("Narrowing expressions") narrowed_exprs = self._narrow_exprs() self.simplified |= narrowed_exprs if narrowed_exprs: _l.debug("... expressions narrowed") self._rebuild_func_graph() self._clear_cache() _l.debug("Folding expressions") folded_exprs = self._fold_exprs() self.simplified |= folded_exprs if folded_exprs: _l.debug("... expressions folded") self._rebuild_func_graph() # reaching definition analysis results are no longer reliable self._clear_cache() if self._only_consts: return if self._should_rewrite_ccalls: _l.debug("Rewriting ccalls") ccalls_rewritten = self._rewrite_ccalls() self.simplified |= ccalls_rewritten if ccalls_rewritten: _l.debug("... ccalls rewritten") self._rebuild_func_graph() self._clear_cache() if self._unify_vars: _l.debug("Removing dead assignments") r = self._remove_dead_assignments() if r: _l.debug("... dead assignments removed") self.simplified = True self._rebuild_func_graph() self._clear_cache() _l.debug("Unifying local variables") r = self._unify_local_variables() if r: _l.debug("... local variables unified") self.simplified = True self._rebuild_func_graph() # _fold_call_exprs() may set self._calls_to_remove, which will be honored in _remove_dead_assignments() _l.debug("Folding call expressions") r = self._fold_call_exprs() if r: _l.debug("... call expressions folded") self.simplified = True self._rebuild_func_graph() self._clear_cache() _l.debug("Removing dead assignments") r = self._remove_dead_assignments() if r: _l.debug("... 
dead assignments removed") self.simplified = True self._rebuild_func_graph() def _rebuild_func_graph(self): def _handler(node): return self.blocks.get(node, None) AILGraphWalker(self.func_graph, _handler, replace_nodes=True).walk() self.blocks = {} def _compute_reaching_definitions(self) -> SRDAModel: # Computing reaching definitions or return the cached one if self._reaching_definitions is not None: return self._reaching_definitions func_args = {vvar for vvar, _ in self._arg_vvars.values()} if self._arg_vvars else set() rd = self.project.analyses.SReachingDefinitions( subject=self.func, func_graph=self.func_graph, func_args=func_args, # use_callee_saved_regs_at_return=self._use_callee_saved_regs_at_return, # track_tmps=True, ).model self._reaching_definitions = rd return rd def _compute_propagation(self) -> SPropagatorAnalysis: # Propagate expressions or return the existing result if self._propagator is not None: return self._propagator func_args = {vvar for vvar, _ in self._arg_vvars.values()} if self._arg_vvars else set() prop = self.project.analyses[SPropagatorAnalysis].prep(fail_fast=self._fail_fast)( subject=self.func, func_graph=self.func_graph, func_args=func_args, # gp=self._gp, only_consts=self._only_consts, ) self._propagator = prop return prop def _compute_equivalence(self) -> set[Equivalence]: equivalence = set() for block in self.func_graph: for stmt_idx, stmt in enumerate(block.statements): if isinstance(stmt, Assignment): if isinstance(stmt.dst, VirtualVariable) and isinstance( stmt.src, (VirtualVariable, Tmp, Call, Convert) ): codeloc = CodeLocation(block.addr, stmt_idx, block_idx=block.idx, ins_addr=stmt.ins_addr) equivalence.add(Equivalence(codeloc, stmt.dst, stmt.src)) elif isinstance(stmt, Call): if isinstance(stmt.ret_expr, (VirtualVariable, Load)): codeloc = CodeLocation(block.addr, stmt_idx, block_idx=block.idx, ins_addr=stmt.ins_addr) equivalence.add(Equivalence(codeloc, stmt.ret_expr, stmt)) elif ( isinstance(stmt, Store) and 
isinstance(stmt.size, int) and isinstance(stmt.data, (VirtualVariable, Tmp, Call, Convert)) ): if isinstance(stmt.addr, StackBaseOffset) and isinstance(stmt.addr.offset, int): # stack variable atom = SimStackVariable(stmt.addr.offset, stmt.size) codeloc = CodeLocation(block.addr, stmt_idx, block_idx=block.idx, ins_addr=stmt.ins_addr) equivalence.add(Equivalence(codeloc, atom, stmt.data)) elif isinstance(stmt.addr, Const): # global variable atom = SimMemoryVariable(stmt.addr.value, stmt.size) codeloc = CodeLocation(block.addr, stmt_idx, block_idx=block.idx, ins_addr=stmt.ins_addr) equivalence.add(Equivalence(codeloc, atom, stmt.data)) return equivalence def _clear_cache(self) -> None: self._propagator = None self._reaching_definitions = None def _clear_propagator_cache(self) -> None: self._propagator = None def _clear_reaching_definitions_cache(self) -> None: self._reaching_definitions = None # # Expression narrowing # def _narrow_exprs(self) -> bool: """ A register may be used with full width even when only the lower bytes are really needed. This results in the incorrect determination of wider variables while the actual variable is narrower (e.g., int64 vs char). This optimization narrows a register definition if all its uses are narrower than the definition itself. 
""" narrowed = False addr_and_idx_to_block: dict[tuple[int, int], Block] = {} for block in self.func_graph.nodes(): addr_and_idx_to_block[(block.addr, block.idx)] = block rd = self._compute_reaching_definitions() sorted_defs = sorted(rd.all_definitions, key=lambda d: d.codeloc, reverse=True) narrowing_candidates: dict[int, tuple[Definition, ExprNarrowingInfo]] = {} for def_ in (d_ for d_ in sorted_defs if d_.codeloc.context is None): if isinstance(def_.atom, atoms.VirtualVariable) and (def_.atom.was_reg or def_.atom.was_parameter): # only do this for general purpose register skip_def = False for reg in self.project.arch.register_list: if not reg.artificial and reg.vex_offset == def_.atom.reg_offset and not reg.general_purpose: skip_def = True break if skip_def: continue narrow = self._narrowing_needed(def_, rd, addr_and_idx_to_block) if narrow.narrowable: # we cannot narrow it immediately because any definition that is used by phi variables must be # narrowed together with all other definitions that can reach the phi variables. # so we record the information and decide if we are going to narrow these expressions or not at the # end of the loop. narrowing_candidates[def_.atom.varid] = def_, narrow # first, determine which phi vars need to be narrowed and can be narrowed. 
# a phi var can only be narrowed if all its source vvars are narrowable vvar_to_narrowing_size = {} for def_varid, (_, narrow_info) in narrowing_candidates.items(): vvar_to_narrowing_size[def_varid] = narrow_info.to_size blacklist_varids = set() while True: repeat, narrowables = self._compute_narrowables_once( rd, narrowing_candidates, vvar_to_narrowing_size, blacklist_varids ) if not repeat: break # let's narrow them (finally) narrower = ExpressionNarrower(self.project, rd, narrowables, addr_and_idx_to_block, self.blocks) for old_block in addr_and_idx_to_block.values(): new_block = self.blocks.get(old_block, old_block) new_block = narrower.walk(new_block) if narrower.narrowed_any: narrowed = True self.blocks[old_block] = new_block # update self._arg_vvars if necessary for new_vvars in narrower.replacement_core_vvars.values(): for new_vvar in new_vvars: if new_vvar.was_parameter and self._arg_vvars: for func_arg_idx in list(self._arg_vvars): vvar, simvar = self._arg_vvars[func_arg_idx] if vvar.varid == new_vvar.varid: simvar_new = simvar.copy() simvar_new._hash = None simvar_new.size = new_vvar.size self._arg_vvars[func_arg_idx] = new_vvar, simvar_new return narrowed @staticmethod def _compute_narrowables_once( rd, narrowing_candidates: dict, vvar_to_narrowing_size: dict[int, int], blacklist_varids: set ): repeat = False narrowable_phivarids = set() for def_vvarid in narrowing_candidates: if def_vvarid in blacklist_varids: continue if def_vvarid in rd.phi_vvar_ids: narrowing_sizes = set() src_vvarids = rd.phivarid_to_varids[def_vvarid] for vvarid in src_vvarids: if vvarid in blacklist_varids: narrowing_sizes.add(None) else: narrowing_sizes.add(vvar_to_narrowing_size.get(vvarid)) if len(narrowing_sizes) == 1 and None not in narrowing_sizes: # we can narrow this phi vvar! narrowable_phivarids.add(def_vvarid) else: # blacklist it for now blacklist_varids.add(def_vvarid) # now determine what to narrow! 
narrowables = [] for def_, narrow_info in narrowing_candidates.values(): if def_.atom.varid in blacklist_varids: continue if not narrow_info.phi_vars: # not used by any other phi variables. good! narrowables.append((def_, narrow_info)) else: if {phivar.varid for phivar in narrow_info.phi_vars}.issubset(narrowable_phivarids): # all phi vvars that use this definition can be narrowed narrowables.append((def_, narrow_info)) else: # this vvar cannot be narrowed # note that all phi variables that relies on this vvar also cannot be narrowed! we must analyze # again repeat = True blacklist_varids.add(def_.atom.varid) blacklist_varids |= {phivar.varid for phivar in narrow_info.phi_vars} return repeat, narrowables def _narrowing_needed(self, def_, rd: SRDAModel, addr_and_idx_to_block) -> ExprNarrowingInfo: def_size = def_.size # find its uses # some use locations are phi assignments. we keep tracking the uses of phi variables and update the dictionary result = self._get_vvar_use_and_exprs_recursive(def_.atom, rd, addr_and_idx_to_block) if result is None: return ExprNarrowingInfo(False) use_and_exprs, phi_vars = result all_used_sizes = set() used_by: list[tuple[atoms.VirtualVariable, CodeLocation, tuple[str, tuple[Expression, ...]]]] = [] used_by_loc = defaultdict(list) for atom, loc, expr in use_and_exprs: old_block = addr_and_idx_to_block.get((loc.block_addr, loc.block_idx), None) if old_block is None: # missing a block for whatever reason return ExprNarrowingInfo(False) block = self.blocks.get(old_block, old_block) if loc.stmt_idx >= len(block.statements): # missing a statement for whatever reason return ExprNarrowingInfo(False) stmt = block.statements[loc.stmt_idx] # special case: if the statement is a Call statement and expr is None, it means we have not been able to # determine if the expression is really used by the call or not. 
skip it in this case if isinstance(stmt, Call) and expr is None: continue # special case: if the statement is a phi statement, we ignore it if is_phi_assignment(stmt): continue expr_size, used_by_exprs = self._extract_expression_effective_size(stmt, expr) if expr_size is None: # it's probably used in full width return ExprNarrowingInfo(False) all_used_sizes.add(expr_size) used_by_loc[loc].append((atom, used_by_exprs)) if len(all_used_sizes) == 1 and next(iter(all_used_sizes)) < def_size: for loc, atom_expr_pairs in used_by_loc.items(): if len(atom_expr_pairs) == 1: atom, used_by_exprs = atom_expr_pairs[0] used_by.append((atom, loc, used_by_exprs)) else: # the order matters - we must replace the outer expressions first, then replace the inner # expressions. replacing in the wrong order will lead to expressions that are not replaced in the # end. ordered = [] for atom, used_by_exprs in atom_expr_pairs: last_inclusion = len(ordered) - 1 # by default we append at the end of the list for idx in range(len(ordered)): if self._is_expr0_included_in_expr1(ordered[idx][1], used_by_exprs): # this element must be inserted before idx ordered.insert(idx, (atom, used_by_exprs)) break if self._is_expr0_included_in_expr1(used_by_exprs, ordered[idx][1]): # this element can be inserted after this element. 
record the index last_inclusion = idx else: ordered.insert(last_inclusion + 1, (atom, used_by_exprs)) for atom, used_by_exprs in ordered: used_by.append((atom, loc, used_by_exprs)) return ExprNarrowingInfo(True, to_size=next(iter(all_used_sizes)), use_exprs=used_by, phi_vars=phi_vars) return ExprNarrowingInfo(False) @staticmethod def _exprs_from_used_by_exprs(used_by_exprs) -> set[Expression]: use_type, expr_tuple = used_by_exprs match use_type: case "expr" | "mask" | "convert": return {expr_tuple[1]} if len(expr_tuple) == 2 else {expr_tuple[0]} case "phi-src-expr": return {expr_tuple[0]} case "binop-convert": return {expr_tuple[0], expr_tuple[1]} case _: return set() def _is_expr0_included_in_expr1(self, used_by_exprs0, used_by_exprs1) -> bool: # extract expressions exprs0 = self._exprs_from_used_by_exprs(used_by_exprs0) exprs1 = self._exprs_from_used_by_exprs(used_by_exprs1) # test for inclusion for expr1 in exprs1: walker = HasExprWalker(exprs0) walker.walk_expression(expr1) if walker.contains_exprs: return True return False def _get_vvar_use_and_exprs_recursive( self, initial_atom: atoms.VirtualVariable, rd, block_dict: dict[tuple[int, int | None], Block] ) -> tuple[list[tuple[atoms.VirtualVariable, CodeLocation, Expression]], set[VirtualVariable]] | None: result = [] atom_queue = [initial_atom] phi_vars = set() seen = set() while atom_queue: atom = atom_queue.pop(0) seen.add(atom) use_and_exprs = rd.get_vvar_uses_with_expr(atom) for loc, expr in use_and_exprs: old_block = block_dict.get((loc.block_addr, loc.block_idx), None) if old_block is None: # missing a block for whatever reason return None block: Block = self.blocks.get(old_block, old_block) if loc.stmt_idx >= len(block.statements): # missing a statement for whatever reason return None stmt = block.statements[loc.stmt_idx] if is_phi_assignment(stmt): phi_vars.add(stmt.dst) new_atom = atoms.VirtualVariable( stmt.dst.varid, stmt.dst.size, stmt.dst.category, oident=stmt.dst.oident ) if new_atom not in seen: 
atom_queue.append(new_atom) else: result.append((atom, loc, expr)) return result, phi_vars def _extract_expression_effective_size( self, statement, expr ) -> tuple[int | None, tuple[str, tuple[Expression, ...]] | None]: """ Determine the effective size of an expression when it's used. """ walker = NarrowingInfoExtractor(expr) walker.walk_statement(statement) if not walker.operations: if expr is None: return None, None return expr.size, ("expr", (expr,)) first_op = walker.operations[0] if isinstance(first_op, Convert) and first_op.to_bits >= self.project.arch.byte_width: # we need at least one byte! return first_op.to_bits // self.project.arch.byte_width, ("convert", (first_op,)) if isinstance(first_op, BinaryOp): second_op = None if len(walker.operations) >= 2: second_op = walker.operations[1] if ( first_op.op == "And" and isinstance(first_op.operands[1], Const) and ( second_op is None or (isinstance(second_op, BinaryOp) and isinstance(second_op.operands[1], Const)) ) ): mask = first_op.operands[1].value if mask == 0xFF: return 1, ("mask", (first_op, second_op)) if second_op is not None else ("mask", (first_op,)) if mask == 0xFFFF: return 2, ("mask", (first_op, second_op)) if second_op is not None else ("mask", (first_op,)) if mask == 0xFFFF_FFFF: return 4, ("mask", (first_op, second_op)) if second_op is not None else ("mask", (first_op,)) if ( (first_op.operands[0] is expr or first_op.operands[1] is expr) and first_op.op not in {"Shr", "Sar"} and isinstance(second_op, Convert) and second_op.from_bits == expr.bits and second_op.to_bits >= self.project.arch.byte_width # we need at least one byte! ): return min(expr.bits, second_op.to_bits) // self.project.arch.byte_width, ( "binop-convert", (expr, first_op, second_op), ) if expr is None: return None, None return expr.size, ("expr", (expr,)) # # Expression folding # def _fold_exprs(self): """ Fold expressions: Fold assigned expressions that are constant or only used once. 
""" # propagator propagator = self._compute_propagation() replacements = propagator.replacements # take replacements and rebuild the corresponding blocks replacements_by_block_addrs_and_idx = defaultdict(dict) for codeloc, reps in replacements.items(): if reps: replacements_by_block_addrs_and_idx[(codeloc.block_addr, codeloc.block_idx)][codeloc] = reps if not replacements_by_block_addrs_and_idx: return False blocks_by_addr_and_idx = {(node.addr, node.idx): node for node in self.func_graph.nodes()} if self._stack_arg_offsets: insn_addrs_using_stack_args = {ins_addr for ins_addr, _ in self._stack_arg_offsets} else: insn_addrs_using_stack_args = None replaced = False for (block_addr, block_idx), reps in replacements_by_block_addrs_and_idx.items(): block = blocks_by_addr_and_idx[(block_addr, block_idx)] # only replace loads if there are stack arguments in this block replace_loads = insn_addrs_using_stack_args is not None and { stmt.ins_addr for stmt in block.statements }.intersection(insn_addrs_using_stack_args) # remove virtual variables in the avoid list if self._avoid_vvar_ids: filtered_reps = {} for loc, rep_dict in reps.items(): filtered_reps[loc] = { k: v for k, v in rep_dict.items() if not (isinstance(k, VirtualVariable) and k.varid in self._avoid_vvar_ids) } reps = filtered_reps r, new_block = BlockSimplifier._replace_and_build(block, reps, gp=self._gp, replace_loads=replace_loads) replaced |= r self.blocks[block] = new_block if replaced: # blocks have been rebuilt - expression propagation results are no longer reliable self._clear_cache() return replaced # # Unifying local variables # def _unify_local_variables(self) -> bool: """ Find variables that are definitely equivalent and then eliminate unnecessary copies. 
""" simplified = False equivalence = self._compute_equivalence() if not equivalence: return simplified addr_and_idx_to_block: dict[tuple[int, int], Block] = {} for block in self.func_graph.nodes(): addr_and_idx_to_block[(block.addr, block.idx)] = block equivalences: dict[Any, set[Equivalence]] = defaultdict(set) atom_by_loc = set() for eq in equivalence: equivalences[eq.atom1].add(eq) atom_by_loc.add((eq.codeloc, eq.atom1)) # sort keys to ensure a reproducible result sorted_loc_and_atoms = sorted(atom_by_loc, key=lambda x: x[0]) for _, atom in sorted_loc_and_atoms: eqs = equivalences[atom] if len(eqs) > 1: continue eq = next(iter(eqs)) # Acceptable equivalence classes: # # stack variable == register # register variable == register # stack variable == Conv(register, M->N) # global variable == register # # Equivalence is generally created at assignment sites. Therefore, eq.atom0 is the definition and # eq.atom1 is the use. the_def = None if (isinstance(eq.atom0, VirtualVariable) and eq.atom0.was_stack) or ( isinstance(eq.atom0, SimMemoryVariable) and not isinstance(eq.atom0, SimStackVariable) and isinstance(eq.atom0.addr, int) ): if isinstance(eq.atom1, VirtualVariable) and eq.atom1.was_reg: # stack_var == register or global_var == register to_replace = eq.atom1 to_replace_is_def = False elif ( isinstance(eq.atom0, VirtualVariable) and eq.atom0.was_stack and isinstance(eq.atom1, VirtualVariable) and eq.atom1.was_parameter ): # stack_var == parameter to_replace = eq.atom0 to_replace_is_def = True elif ( isinstance(eq.atom1, Convert) and isinstance(eq.atom1.operand, VirtualVariable) and eq.atom1.operand.was_reg ): # stack_var == Conv(register, M->N) to_replace = eq.atom1.operand to_replace_is_def = False else: continue elif isinstance(eq.atom0, VirtualVariable) and eq.atom0.was_reg: if isinstance(eq.atom1, VirtualVariable) and (eq.atom1.was_reg or eq.atom1.was_parameter): # register == register if self.project.arch.is_artificial_register(eq.atom0.reg_offset, 
eq.atom0.size): to_replace = eq.atom0 to_replace_is_def = True else: to_replace = eq.atom1 to_replace_is_def = False else: continue else: continue assert isinstance(to_replace, VirtualVariable) # find the definition of this virtual register rd = self._compute_reaching_definitions() if to_replace_is_def: # find defs defs = [] for def_ in rd.all_definitions: if def_.atom.varid == to_replace.varid: defs.append(def_) if len(defs) != 1: continue the_def = defs[0] else: # find uses defs = rd.get_uses_by_location(eq.codeloc) if len(defs) != 1: # there are multiple defs for this register - we do not support replacing all of them continue for def_ in defs: def_: Definition if ( isinstance(def_.atom, atoms.VirtualVariable) and def_.atom.category == to_replace.category and def_.atom.oident == to_replace.oident ): # found it! the_def = def_ break if the_def is None: continue if the_def.codeloc.context: # FIXME: now the_def.codeloc.context is never filled in # the definition is in a callee function continue if isinstance(the_def.codeloc, ExternalCodeLocation) or ( isinstance(eq.atom1, VirtualVariable) and eq.atom1.was_parameter ): # this is a function argument. we enter a slightly different logic and try to eliminate copies of this # argument if # (a) the on-stack or in-register copy of it has never been modified in this function # (b) the function argument register has never been updated. # TODO: we may loosen requirement (b) once we have real register versioning in AIL. 
defs = [def_ for def_ in rd.all_definitions if def_.codeloc == eq.codeloc] all_uses_with_def = None replace_with = None remove_initial_assignment = None if defs and len(defs) == 1: arg_copy_def = defs[0] if (isinstance(arg_copy_def.atom, atoms.VirtualVariable) and arg_copy_def.atom.was_stack) or ( isinstance(arg_copy_def.atom, atoms.VirtualVariable) and arg_copy_def.atom.was_reg ): # found the copied definition (either a stack variable or a register variable) # Make sure there is no other write to this stack location if the copy is a stack variable if ( isinstance(arg_copy_def.atom, atoms.VirtualVariable) and arg_copy_def.atom.was_stack and any( (def_ != arg_copy_def and def_.atom.stack_offset == arg_copy_def.atom.stack_offset) for def_ in rd.all_definitions if isinstance(def_.atom, atoms.VirtualVariable) and def_.atom.was_stack ) ): continue # Make sure the register is never updated across this function if any( (def_ != the_def and def_.atom == the_def.atom) for def_ in rd.all_definitions if isinstance(def_.atom, atoms.VirtualVariable) and def_.atom.was_reg and rd.get_vvar_uses(def_.atom) ): continue # find all its uses all_arg_copy_var_uses: set[tuple[CodeLocation, Any]] = set( rd.get_vvar_uses_with_expr(arg_copy_def.atom) ) all_uses_with_def = set() should_abort = False for use in all_arg_copy_var_uses: used_expr = use[1] if used_expr is not None and used_expr.size != arg_copy_def.size: should_abort = True break all_uses_with_def.add((arg_copy_def, use)) if should_abort: continue replace_with = eq.atom1 remove_initial_assignment = True if all_uses_with_def is None: continue else: if ( eq.codeloc.block_addr == the_def.codeloc.block_addr and eq.codeloc.block_idx == the_def.codeloc.block_idx ): # the definition and the eq location are within the same block, and the definition is before # the eq location. 
if eq.codeloc.stmt_idx < the_def.codeloc.stmt_idx: continue else: # the definition is in the predecessor block of the eq eq_block = next( iter( bb for bb in self.func_graph if bb.addr == eq.codeloc.block_addr and bb.idx == eq.codeloc.block_idx ) ) eq_block_preds = set(self.func_graph.predecessors(eq_block)) if not any( pred.addr == the_def.codeloc.block_addr and pred.idx == the_def.codeloc.block_idx for pred in eq_block_preds ): continue if isinstance(eq.atom0, VirtualVariable) and eq.atom0.was_stack: # create the replacement expression if isinstance(eq.atom1, VirtualVariable) and eq.atom1.was_parameter: # replacing atom0 new_idx = None if self._ail_manager is None else next(self._ail_manager.atom_ctr) replace_with = VirtualVariable( new_idx, eq.atom1.varid, eq.atom1.bits, category=eq.atom1.category, oident=eq.atom1.oident, **eq.atom1.tags, ) else: # replacing atom1 new_idx = None if self._ail_manager is None else next(self._ail_manager.atom_ctr) replace_with = VirtualVariable( new_idx, eq.atom0.varid, eq.atom0.bits, category=eq.atom0.category, oident=eq.atom0.oident, **eq.atom0.tags, ) elif isinstance(eq.atom0, SimMemoryVariable) and isinstance(eq.atom0.addr, int): # create the memory loading expression new_idx = None if self._ail_manager is None else next(self._ail_manager.atom_ctr) replace_with = Load( new_idx, Const(None, None, eq.atom0.addr, self.project.arch.bits), eq.atom0.size, endness=self.project.arch.memory_endness, **eq.atom1.tags, ) elif isinstance(eq.atom0, VirtualVariable) and eq.atom0.was_reg: if isinstance(eq.atom1, VirtualVariable) and eq.atom1.was_reg: if self.project.arch.is_artificial_register(eq.atom0.reg_offset, eq.atom0.size): replace_with = eq.atom1 else: replace_with = eq.atom0 else: raise AngrRuntimeError(f"Unsupported atom1 type {type(eq.atom1)}.") else: raise AngrRuntimeError(f"Unsupported atom0 type {type(eq.atom0)}.") to_replace_def = the_def # check: the definition of expression being replaced should not be a phi variable if ( 
isinstance(to_replace_def.atom, atoms.VirtualVariable) and to_replace_def.atom.varid in rd.phi_vvar_ids ): continue # find all uses of this definition # we make a copy of the set since we may touch the set (uses) when replacing expressions all_uses: set[tuple[CodeLocation, Any]] = set(rd.get_vvar_uses_with_expr(to_replace_def.atom)) # make sure none of these uses are phi nodes (depends on more than one def) all_uses_with_unique_def = set() for use_and_expr in all_uses: use_loc, used_expr = use_and_expr defs_and_exprs = rd.get_uses_by_location(use_loc, exprs=True) filtered_defs = {def_ for def_, expr_ in defs_and_exprs if expr_ == used_expr} if len(filtered_defs) == 1: all_uses_with_unique_def.add(use_and_expr) else: # optimization: break early break if len(all_uses) != len(all_uses_with_unique_def): # only when all uses are determined by the same definition will we continue with the simplification continue # one more check: there can be at most one assignment in all these use locations if the expression is # not going to be replaced with a parameter. 
the assignment can be an Assignment statement, but may also # be a Store if it's a global variable (via Load) that we are replacing with if not (isinstance(replace_with, VirtualVariable) and replace_with.was_parameter): assignment_ctr = 0 all_use_locs = {use_loc for use_loc, _ in all_uses} for use_loc in all_use_locs: if use_loc == eq.codeloc: continue block = addr_and_idx_to_block[(use_loc.block_addr, use_loc.block_idx)] stmt = block.statements[use_loc.stmt_idx] if isinstance(stmt, Assignment) or (isinstance(replace_with, Load) and isinstance(stmt, Store)): assignment_ctr += 1 if assignment_ctr > 1: continue all_uses_with_def = {(to_replace_def, use_and_expr) for use_and_expr in all_uses} remove_initial_assignment = False # expression folding will take care of it if any(not isinstance(use_and_expr[1], VirtualVariable) for _, use_and_expr in all_uses_with_def): # if any of the uses are phi assignments, we skip used_in_phi_assignment = False for _, use_and_expr in all_uses_with_def: u = use_and_expr[0] block = addr_and_idx_to_block[(u.block_addr, u.block_idx)] stmt = block.statements[u.stmt_idx] if is_phi_assignment(stmt): used_in_phi_assignment = True break if used_in_phi_assignment: continue # ensure the uses we consider are all after the eq location filtered_all_uses_with_def = [] for def_, use_and_expr in all_uses_with_def: u = use_and_expr[0] if ( u.block_addr == eq.codeloc.block_addr and u.block_idx == eq.codeloc.block_idx and u.stmt_idx < eq.codeloc.stmt_idx ): # this use happens before the assignment - ignore it continue filtered_all_uses_with_def.append((def_, use_and_expr)) all_uses_with_def = filtered_all_uses_with_def if not all_uses_with_def: # definitions without uses may simply be our data-flow analysis being incorrect. do not remove them. continue # TODO: We can only replace all these uses with the stack variable if the stack variable isn't # TODO: re-assigned of a new value. Perform this check. 
# replace all uses all_uses_replaced = True for def_, use_and_expr in all_uses_with_def: u, used_expr = use_and_expr use_expr_defns = [] for d in rd.get_uses_by_location(u): if ( isinstance(d.atom, atoms.VirtualVariable) and d.atom.was_reg and isinstance(def_.atom, atoms.VirtualVariable) and def_.atom.was_reg and d.atom.reg_offset == def_.atom.reg_offset ) or d.atom == def_.atom: use_expr_defns.append(d) # you can never replace a use with dependencies from outside the checked defn if len(use_expr_defns) != 1 or next(iter(use_expr_defns)) != def_: if not use_expr_defns: _l.warning("There was no use_expr_defns for %s, this is likely a bug", u) # TODO: can you have multiple definitions which can all be eliminated? all_uses_replaced = False continue if u == eq.codeloc: # skip the very initial assignment location continue old_block = addr_and_idx_to_block.get((u.block_addr, u.block_idx), None) if old_block is None: continue if used_expr is None: all_uses_replaced = False continue # ensure the expression that we want to replace with is still up-to-date replace_with_original_def = self._find_atom_def_at(replace_with, rd, def_.codeloc) if replace_with_original_def is not None and not self._check_atom_last_def( replace_with, u, rd, replace_with_original_def ): all_uses_replaced = False continue # if there is an updated block, use it the_block = self.blocks.get(old_block, old_block) stmt: Statement = the_block.statements[u.stmt_idx] replace_with_copy = replace_with.copy() if used_expr.size != replace_with_copy.size: new_idx = None if self._ail_manager is None else next(self._ail_manager.atom_ctr) replace_with_copy = Convert( new_idx, replace_with_copy.bits, used_expr.bits, False, replace_with_copy, ) r, new_block = self._replace_expr_and_update_block( the_block, u.stmt_idx, stmt, used_expr, replace_with_copy ) if r: self.blocks[old_block] = new_block else: # failed to replace a use - we need to keep the initial assignment! 
all_uses_replaced = False simplified |= r if all_uses_replaced and remove_initial_assignment: # the initial statement can be removed self._assignments_to_remove.add(eq.codeloc) if simplified: self._clear_cache() return simplified @staticmethod def _find_atom_def_at(atom, rd, codeloc: CodeLocation) -> Definition | None: if isinstance(atom, Register): defs = rd.get_defs(atom, codeloc, OP_BEFORE) return next(iter(defs)) if len(defs) == 1 else None return None @staticmethod def _check_atom_last_def(atom, codeloc, rd, the_def) -> bool: if isinstance(atom, Register): defs = rd.get_defs(atom, codeloc, OP_BEFORE) for d in defs: if d.codeloc != the_def.codeloc: return False return True # # Folding call expressions # @staticmethod def _is_expr_using_temporaries(expr: Expression) -> bool: walker = AILBlockTempCollector() walker.walk_expression(expr) return len(walker.temps) > 0 @staticmethod def _is_stmt_using_temporaries(stmt: Statement) -> bool: walker = AILBlockTempCollector() walker.walk_statement(stmt) return len(walker.temps) > 0 def _fold_call_exprs(self) -> bool: """ Fold a call expression (statement) into other statements if the return value of the call expression (statement) is only used once, and the use site and the call site belongs to the same supernode. Example:: s1 = func(); s0 = s1; if (s0) ... after folding, it will be transformed to:: s0 = func(); if (s0) ... s0 can be folded into the condition, which means this example can further be transformed to:: if (func()) ... this behavior is controlled by fold_callexprs_into_conditions. This to avoid cases where func() is called more than once after simplification and graph structuring where conditions might be duplicated (e.g., in Dream). In such cases, the one-use expression folder in RegionSimplifier will perform this transformation. 
""" simplified = False equivalence = self._compute_equivalence() if not equivalence: return simplified addr_and_idx_to_block: dict[tuple[int, int], Block] = {} for block in self.func_graph.nodes(): addr_and_idx_to_block[(block.addr, block.idx)] = block def_locations_to_remove: set[CodeLocation] = set() updated_use_locations: set[CodeLocation] = set() for eq in equivalence: # register variable == Call if isinstance(eq.atom0, VirtualVariable) and eq.atom0.was_reg: if isinstance(eq.atom1, Call): # register variable = Call call: Expression = eq.atom1 # call_addr = call.target.value if isinstance(call.target, Const) else None elif isinstance(eq.atom1, Convert) and isinstance(eq.atom1.operand, Call): # register variable = Convert(Call) call = eq.atom1 # call_addr = call.operand.target.value if isinstance(call.operand.target, Const) else None else: continue if self._is_expr_using_temporaries(call): continue if eq.codeloc in updated_use_locations: # this def is now created by an updated use. the corresponding statement will be updated in the end. # we must rerun Propagator to get an updated definition (and Equivalence) continue # find all uses of this virtual register rd = self._compute_reaching_definitions() the_def: Definition = Definition( atoms.VirtualVariable( eq.atom0.varid, eq.atom0.size, category=eq.atom0.category, oident=eq.atom0.oident ), eq.codeloc, ) all_uses: set[tuple[CodeLocation, Any]] = set(rd.get_vvar_uses_with_expr(the_def.atom)) if len(all_uses) != 1: continue u, used_expr = next(iter(all_uses)) if used_expr is None: continue if u in def_locations_to_remove: # this use site has been altered by previous folding attempts. the corresponding statement will be # removed in the end. in this case, this Equivalence is probably useless, and we must rerun # Propagator to get an updated Equivalence. 
continue if not self._fold_callexprs_into_conditions: # check the statement and make sure it's not a conditional jump the_block = addr_and_idx_to_block[(u.block_addr, u.block_idx)] if isinstance(the_block.statements[u.stmt_idx], ConditionalJump): continue # check if the use and the definition is within the same supernode super_node_blocks = self._get_super_node_blocks( addr_and_idx_to_block[(the_def.codeloc.block_addr, the_def.codeloc.block_idx)] ) if u.block_addr not in {b.addr for b in super_node_blocks}: continue # check if the register has been overwritten by statements in between the def site and the use site # usesite_atom_defs = set(rd.get_defs(the_def.atom, u, OP_BEFORE)) # if len(usesite_atom_defs) != 1: # continue # usesite_atom_def = next(iter(usesite_atom_defs)) # if usesite_atom_def != the_def: # continue # check if any atoms that the call relies on has been overwritten by statements in between the def site # and the use site. # TODO: Prove non-interference # defsite_all_expr_uses = set(rd.all_uses.get_uses_by_location(the_def.codeloc)) # defsite_used_atoms = set() # for dd in defsite_all_expr_uses: # defsite_used_atoms.add(dd.atom) # usesite_expr_def_outdated = False # for defsite_expr_atom in defsite_used_atoms: # usesite_expr_uses = set(rd.get_defs(defsite_expr_atom, u, OP_BEFORE)) # if not usesite_expr_uses: # # the atom is not defined at the use site - it's fine # continue # defsite_expr_uses = set(rd.get_defs(defsite_expr_atom, the_def.codeloc, OP_BEFORE)) # if usesite_expr_uses != defsite_expr_uses: # # special case: ok if this atom is assigned to at the def site and has not been overwritten # if len(usesite_expr_uses) == 1: # usesite_expr_use = next(iter(usesite_expr_uses)) # if usesite_expr_use.atom == defsite_expr_atom and ( # usesite_expr_use.codeloc == the_def.codeloc # or usesite_expr_use.codeloc.block_addr == call_addr # ): # continue # usesite_expr_def_outdated = True # break # if usesite_expr_def_outdated: # continue # check if there 
are any calls in between the def site and the use site if self._count_calls_in_supernodeblocks(super_node_blocks, the_def.codeloc, u) > 0: continue # replace all uses old_block = addr_and_idx_to_block.get((u.block_addr, u.block_idx), None) if old_block is None: continue # if there is an updated block, use that the_block = self.blocks.get(old_block, old_block) stmt: Statement = the_block.statements[u.stmt_idx] if isinstance(eq.atom0, VirtualVariable): src = used_expr dst: Call | Convert = call.copy() if isinstance(dst, Call) and dst.ret_expr is not None: dst_bits = dst.ret_expr.bits # clear the ret_expr and fp_ret_expr of dst, then set bits so that it can be used as an # expression dst.ret_expr = None dst.fp_ret_expr = None dst.bits = dst_bits if src.bits != dst.bits: dst = Convert(None, dst.bits, src.bits, False, dst) else: continue # ensure what we are going to replace only appears once expr_ctr = SingleExpressionCounter(stmt, src) if expr_ctr.count > 1: continue replaced, new_block = self._replace_expr_and_update_block(the_block, u.stmt_idx, stmt, src, dst) if replaced: self.blocks[old_block] = new_block # this call has been folded to the use site. we can remove this call. 
self._calls_to_remove.add(eq.codeloc) simplified = True def_locations_to_remove.add(eq.codeloc) updated_use_locations.add(u) # no need to clear the cache at the end of this method return simplified def _get_super_node_blocks(self, start_node: Block) -> list[Block]: lst: list[Block] = [start_node] while True: b = lst[-1] successors = list(self.func_graph.successors(b)) if len(successors) == 0: break if len(successors) == 1: succ = successors[0] # check its predecessors succ_predecessors = list(self.func_graph.predecessors(succ)) if len(succ_predecessors) == 1: lst.append(succ) else: break else: # too many successors break return lst @staticmethod def _replace_expr_and_update_block(block, stmt_idx, stmt, src_expr, dst_expr) -> tuple[bool, Block | None]: replaced, new_stmt = stmt.replace(src_expr, dst_expr) if replaced: new_block = block.copy() new_block.statements = block.statements[::] new_block.statements[stmt_idx] = new_stmt return True, new_block return False, None def _remove_dead_assignments(self) -> bool: # keeping tracking of statements to remove and statements (as well as dead vvars) to keep allows us to handle # cases where a statement defines more than one atoms, e.g., a call statement that defines both the return # value and the floating-point return value. 
stmts_to_remove_per_block: dict[tuple[int, int], set[int]] = defaultdict(set) stmts_to_keep_per_block: dict[tuple[int, int], set[int]] = defaultdict(set) dead_vvar_ids: set[int] = set() # Find all statements that should be removed mask = (1 << self.project.arch.bits) - 1 rd = self._compute_reaching_definitions() stackarg_offsets = ( {(tpl[1] & mask) for tpl in self._stack_arg_offsets} if self._stack_arg_offsets is not None else None ) for def_ in rd.all_definitions: if def_.dummy: continue # we do not remove references to global memory regions no matter what if isinstance(def_.atom, atoms.MemoryLocation) and isinstance(def_.atom.addr, int): continue if isinstance(def_.atom, atoms.VirtualVariable): if def_.atom.was_stack: if not self._remove_dead_memdefs: if rd.is_phi_vvar_id(def_.atom.varid): # we always remove unused phi variables pass elif stackarg_offsets is not None: # we always remove definitions for stack arguments if (def_.atom.stack_offset & mask) not in stackarg_offsets: continue else: continue uses = rd.get_vvar_uses(def_.atom) elif def_.atom.was_tmp or def_.atom.was_reg or def_.atom.was_parameter: uses = rd.get_vvar_uses(def_.atom) else: uses = set() else: continue if not uses: if isinstance(def_.atom, atoms.VirtualVariable): dead_vvar_ids.add(def_.atom.varid) if not isinstance(def_.codeloc, ExternalCodeLocation): stmts_to_remove_per_block[(def_.codeloc.block_addr, def_.codeloc.block_idx)].add( def_.codeloc.stmt_idx ) else: stmts_to_keep_per_block[(def_.codeloc.block_addr, def_.codeloc.block_idx)].add(def_.codeloc.stmt_idx) # find all phi variables that rely on variables that no longer exist all_removed_var_ids = self._removed_vvar_ids.copy() removed_vvar_ids = self._removed_vvar_ids while True: new_removed_vvar_ids = set() for phi_varid, phi_use_varids in rd.phivarid_to_varids.items(): if phi_varid not in all_removed_var_ids and any( vvarid in removed_vvar_ids for vvarid in phi_use_varids ): loc = rd.all_vvar_definitions[rd.varid_to_vvar[phi_varid]] 
stmts_to_remove_per_block[(loc.block_addr, loc.block_idx)].add(loc.stmt_idx) new_removed_vvar_ids.add(phi_varid) all_removed_var_ids.add(phi_varid) if not new_removed_vvar_ids: break removed_vvar_ids = new_removed_vvar_ids # find all phi variables that are only ever used by other phi variables redundant_phi_and_dirty_varids = self._find_cyclic_dependent_phis_and_dirty_vvars(rd) for varid in redundant_phi_and_dirty_varids: loc = rd.all_vvar_definitions[rd.varid_to_vvar[varid]] stmts_to_remove_per_block[(loc.block_addr, loc.block_idx)].add(loc.stmt_idx) stmts_to_keep_per_block[(loc.block_addr, loc.block_idx)].discard(loc.stmt_idx) for codeloc in self._calls_to_remove | self._assignments_to_remove: # this call can be removed. make sure it exists in stmts_to_remove_per_block stmts_to_remove_per_block[codeloc.block_addr, codeloc.block_idx].add(codeloc.stmt_idx) simplified = False # Remove the statements for old_block in self.func_graph.nodes(): # if there is an updated block, use it block = self.blocks.get(old_block, old_block) if not isinstance(block, Block): continue if (block.addr, block.idx) not in stmts_to_remove_per_block: continue new_statements = [] stmts_to_remove = stmts_to_remove_per_block[(block.addr, block.idx)] stmts_to_keep = stmts_to_keep_per_block[(block.addr, block.idx)] if not stmts_to_remove: continue for idx, stmt in enumerate(block.statements): if idx in stmts_to_remove and idx in stmts_to_keep and isinstance(stmt, Call): # this statement declares more than one variable. 
we should handle it surgically # case 1: stmt.ret_expr and stmt.fp_ret_expr are both set, but one of them is not used if isinstance(stmt.ret_expr, VirtualVariable) and stmt.ret_expr.varid in dead_vvar_ids: stmt = stmt.copy() stmt.ret_expr = None simplified = True if isinstance(stmt.fp_ret_expr, VirtualVariable) and stmt.fp_ret_expr.varid in dead_vvar_ids: stmt = stmt.copy() stmt.fp_ret_expr = None simplified = True if idx in stmts_to_remove and idx not in stmts_to_keep and not isinstance(stmt, DirtyStatement): if isinstance(stmt, (Assignment, Store)): # Special logic for Assignment and Store statements # if this statement triggers a call, it should only be removed if it's in self._calls_to_remove codeloc = CodeLocation(block.addr, idx, ins_addr=stmt.ins_addr, block_idx=block.idx) if codeloc in self._assignments_to_remove: # it should be removed simplified = True continue if self._statement_has_call_exprs(stmt): if codeloc in self._calls_to_remove: # it has a call and must be removed simplified = True continue if isinstance(stmt, Assignment) and isinstance(stmt.dst, VirtualVariable): # no one is using the returned virtual variable. # now the things are a bit tricky here if isinstance(stmt.src, Call): # replace this assignment statement with a call statement stmt = stmt.src elif isinstance(stmt.src, Convert) and isinstance(stmt.src.operand, Call): # the convert is useless now stmt = stmt.src.operand else: # we can't change this stmt at all because it has an expression with Calls inside pass else: # no calls. 
remove it simplified = True continue elif isinstance(stmt, Call): codeloc = CodeLocation(block.addr, idx, ins_addr=stmt.ins_addr, block_idx=block.idx) if codeloc in self._calls_to_remove: # this call can be removed simplified = True continue if stmt.ret_expr is not None or stmt.fp_ret_expr is not None: # both the return expr and the fp_ret_expr are not used stmt = stmt.copy() stmt.ret_expr = None stmt.fp_ret_expr = None simplified = True else: # Should not happen! raise NotImplementedError new_statements.append(stmt) new_block = block.copy() new_block.statements = new_statements self.blocks[old_block] = new_block return simplified def _find_cyclic_dependent_phis_and_dirty_vvars(self, rd: SRDAModel) -> set[int]: blocks_dict = {(bb.addr, bb.idx): bb for bb in self.func_graph} # find dirty vvars and vexccall vvars dirty_vvar_ids = set() for bb in self.func_graph: for stmt in bb.statements: if ( isinstance(stmt, Assignment) and isinstance(stmt.dst, VirtualVariable) and stmt.dst.was_reg and isinstance(stmt.src, (DirtyExpression, VEXCCallExpression)) ): dirty_vvar_ids.add(stmt.dst.varid) phi_and_dirty_vvar_ids = rd.phi_vvar_ids | dirty_vvar_ids vvar_used_by: dict[int, set[int]] = defaultdict(set) for var_id in phi_and_dirty_vvar_ids: if var_id in rd.phivarid_to_varids: for used_by_varid in rd.phivarid_to_varids[var_id]: vvar_used_by[used_by_varid].add(var_id) vvar = rd.varid_to_vvar[var_id] used_by = set() for used_vvar, loc in rd.all_vvar_uses[vvar]: if used_vvar is None: # no explicit reference used_by.add(None) else: stmt = blocks_dict[loc.block_addr, loc.block_idx].statements[loc.stmt_idx] if isinstance(stmt, Assignment) and isinstance(stmt.dst, VirtualVariable): used_by.add(stmt.dst.varid) else: used_by.add(None) vvar_used_by[var_id] |= used_by g = networkx.DiGraph() dummy_vvar_id = -1 for var_id, used_by_initial in vvar_used_by.items(): for u in used_by_initial: if u is None: # we can't have None in networkx.DiGraph g.add_edge(var_id, dummy_vvar_id) else: 
g.add_edge(var_id, u) cyclic_dependent_phi_varids = set() for scc in networkx.strongly_connected_components(g): if len(scc) == 1: continue bail = False for varid in scc: # if this vvar is a phi var, ensure this vvar is not used by anything else outside the scc if varid in rd.phi_vvar_ids: succs = list(g.successors(varid)) if any(succ_varid not in scc for succ_varid in succs): bail = True break if bail: continue if all(varid in phi_and_dirty_vvar_ids for varid in scc): cyclic_dependent_phi_varids |= set(scc) return cyclic_dependent_phi_varids # # Rewriting ccalls # def _rewrite_ccalls(self): rewriter_cls = CCALL_REWRITERS.get(self.project.arch.name, None) if rewriter_cls is None: return False walker = None class _any_update: """ Dummy class for storing if any result has been updated. """ v = False def _handle_expr(expr_idx: int, expr: Expression, stmt_idx: int, stmt: Statement, block) -> Expression | None: if isinstance(expr, VEXCCallExpression): rewriter = rewriter_cls(expr, self.project.arch) if rewriter.result is not None: _any_update.v = True return rewriter.result return None return AILBlockWalker._handle_expr(walker, expr_idx, expr, stmt_idx, stmt, block) blocks_by_addr_and_idx = {(node.addr, node.idx): node for node in self.func_graph.nodes()} walker = AILBlockWalker() walker._handle_expr = _handle_expr updated = False for block in blocks_by_addr_and_idx.values(): _any_update.v = False old_block = block.copy() walker.walk(block) if _any_update.v: self.blocks[old_block] = block updated = True return updated @staticmethod def _statement_has_call_exprs(stmt: Statement) -> bool: def _handle_callexpr(expr_idx, expr, stmt_idx, stmt, block): # pylint:disable=unused-argument raise HasCallNotification walker = AILBlockWalker() walker.expr_handlers[Call] = _handle_callexpr try: walker.walk_statement(stmt) except HasCallNotification: return True return False @staticmethod def _expression_has_call_exprs(expr: Expression) -> bool: def _handle_callexpr(expr_idx, expr, 
stmt_idx, stmt, block): # pylint:disable=unused-argument raise HasCallNotification walker = AILBlockWalker() walker.expr_handlers[Call] = _handle_callexpr try: walker.walk_expression(expr) except HasCallNotification: return True return False @staticmethod def _count_calls_in_supernodeblocks(blocks: list[Block], start: CodeLocation, end: CodeLocation) -> int: """ Count the number of call statements in a list of blocks for a single super block between two given code locations (exclusive). """ calls = 0 started = False for b in blocks: if b.addr == start.block_addr: started = True continue if b.addr == end.block_addr: started = False continue if started and b.statements and isinstance(b.statements[-1], Call): calls += 1 return calls @staticmethod def _exprs_contain_vvar(exprs: Iterable[Expression], vvar_ids: set[int]) -> bool: def _handle_VirtualVariable(expr_idx, expr, stmt_idx, stmt, block): # pylint:disable=unused-argument if expr.varid in vvar_ids: raise HasVVarNotification walker = AILBlockWalker() walker.expr_handlers[VirtualVariable] = _handle_VirtualVariable for expr in exprs: try: walker.walk_expression(expr) except HasVVarNotification: return True return False
# Register the AILSimplifier class with the AnalysesHub under the name "AILSimplifier".
AnalysesHub.register_default("AILSimplifier", AILSimplifier)