# Source code for angr.analyses.s_reaching_definitions.s_reaching_definitions

from __future__ import annotations

from ailment.block import Block
from ailment.statement import Assignment, Call, Return
from ailment.expression import VirtualVariable
import networkx

from angr.knowledge_plugins.functions import Function
from angr.knowledge_plugins.key_definitions.constants import ObservationPointType
from angr.code_location import CodeLocation, ExternalCodeLocation
from angr.analyses import Analysis, register_analysis
from angr.utils.ssa import get_vvar_uselocs, get_vvar_deflocs, get_tmp_deflocs, get_tmp_uselocs
from angr.calling_conventions import default_cc
from .s_rda_model import SRDAModel
from .s_rda_view import SRDAView


[文档] class SReachingDefinitionsAnalysis(Analysis): """ Constant and expression propagation that only supports SSA AIL graphs. """
[文档] def __init__( # pylint: disable=too-many-positional-arguments self, subject, func_addr: int | None = None, func_graph: networkx.DiGraph[Block] | None = None, func_args: set[VirtualVariable] | None = None, track_tmps: bool = False, ): if isinstance(subject, Block): self.block = subject self.func = None self.mode = "block" elif isinstance(subject, Function): self.block = None self.func = subject self.mode = "function" else: raise TypeError(f"Unsupported subject type {type(subject)}") self.func_graph = func_graph self.func_addr = func_addr if func_addr is not None else self.func.addr if self.func is not None else None self.func_args = func_args self._track_tmps = track_tmps self._bp_as_gpr = False if self.func is not None: self._bp_as_gpr = self.func.info.get("bp_as_gpr", False) self.model = SRDAModel(func_graph, func_args, self.project.arch) self._analyze()
def _analyze(self): match self.mode: case "block": assert self.block is not None blocks = {(self.block.addr, self.block.idx): self.block} case "function": assert self.func_graph is not None blocks = {(block.addr, block.idx): block for block in self.func_graph} case _: raise NotImplementedError phi_vvars = {} # find all vvar definitions vvar_deflocs = get_vvar_deflocs(blocks.values(), phi_vvars=phi_vvars) # find all explicit vvar uses vvar_uselocs = get_vvar_uselocs(blocks.values()) # update vvar definitions using function arguments if self.func_args: for vvar in self.func_args: if vvar not in vvar_deflocs: vvar_deflocs[vvar] = ExternalCodeLocation() self.model.func_args = self.func_args # update model for vvar, defloc in vvar_deflocs.items(): self.model.varid_to_vvar[vvar.varid] = vvar self.model.all_vvar_definitions[vvar] = defloc for vvar_at_use, useloc in vvar_uselocs[vvar.varid]: self.model.all_vvar_uses[vvar].add((vvar_at_use, useloc)) self.model.phi_vvar_ids = {vvar.varid for vvar in phi_vvars} self.model.phivarid_to_varids = {} for vvar, src_vvars in phi_vvars.items(): self.model.phivarid_to_varids[vvar.varid] = { src_vvar.varid for src_vvar in src_vvars if src_vvar is not None } if self.mode == "function": # fix register definitions for arguments defined_vvarids = {vvar.varid for vvar in vvar_deflocs} undefined_vvarids = set(vvar_uselocs.keys()).difference(defined_vvarids) for vvar_id in undefined_vvarids: used_vvar = next(iter(vvar_uselocs[vvar_id]))[0] self.model.varid_to_vvar[used_vvar.varid] = used_vvar self.model.all_vvar_definitions[used_vvar] = ExternalCodeLocation() self.model.all_vvar_uses[used_vvar] |= vvar_uselocs[vvar_id] srda_view = SRDAView(self.model) # fix register uses at call sites # find all implicit vvar uses call_stmt_ids = [] for block in blocks.values(): for stmt_idx, stmt in enumerate(block.statements): if ( # pylint:disable=too-many-boolean-expressions (isinstance(stmt, Call) and stmt.args is None) or (isinstance(stmt, Assignment) 
and isinstance(stmt.src, Call) and stmt.src.args is None) or (isinstance(stmt, Return) and stmt.ret_exprs and isinstance(stmt.ret_exprs[0], Call)) ): call_stmt_ids.append(((block.addr, block.idx), stmt_idx)) observations = srda_view.observe( [("stmt", insn_stmt_id, ObservationPointType.OP_BEFORE) for insn_stmt_id in call_stmt_ids] ) for key, reg_to_vvarids in observations.items(): _, ((block_addr, block_idx), stmt_idx), _ = key block = blocks[(block_addr, block_idx)] stmt = block.statements[stmt_idx] assert isinstance(stmt, (Call, Assignment, Return)) call = ( stmt if isinstance(stmt, Call) else stmt.src if isinstance(stmt, Assignment) else stmt.ret_exprs[0] ) assert isinstance(call, Call) if call.prototype is None: # without knowing the prototype, we must conservatively add uses to all registers that are # potentially used here if call.calling_convention is not None: cc = call.calling_convention else: # just use all registers in the default calling convention because we don't know anything about # the calling convention yet cc_cls = default_cc(self.project.arch.name) assert cc_cls is not None cc = cc_cls(self.project.arch) codeloc = CodeLocation(block_addr, stmt_idx, block_idx=block_idx, ins_addr=stmt.ins_addr) arg_locs = cc.ARG_REGS for arg_reg_name in arg_locs: reg_offset = self.project.arch.registers[arg_reg_name][0] if reg_offset in reg_to_vvarids: vvarid = reg_to_vvarids[reg_offset] vvar = self.model.varid_to_vvar[vvarid] self.model.all_vvar_uses[vvar].add((None, codeloc)) if self._track_tmps: # track tmps tmp_deflocs = get_tmp_deflocs(blocks.values()) # find all vvar uses tmp_uselocs = get_tmp_uselocs(blocks.values()) # update model for block_loc, d in tmp_deflocs.items(): for tmp_atom, stmt_idx in d.items(): self.model.all_tmp_definitions[block_loc][tmp_atom] = stmt_idx if tmp_atom in tmp_uselocs[block_loc]: for tmp_at_use, use_stmt_idx in tmp_uselocs[block_loc][tmp_atom]: if tmp_atom not in self.model.all_tmp_uses[block_loc]: 
self.model.all_tmp_uses[block_loc][tmp_atom] = set() self.model.all_tmp_uses[block_loc][tmp_atom].add((tmp_at_use, use_stmt_idx))
# Register the analysis class under the name "SReachingDefinitions".
# NOTE(review): presumably this is the name under which the analysis is looked up on a project - confirm against
# angr.analyses.register_analysis.
register_analysis(SReachingDefinitionsAnalysis, "SReachingDefinitions")