from __future__ import annotations
from typing import Any
from collections.abc import Callable
from collections import defaultdict
import networkx
import pyvex
import claripy
from angr.utils.bits import s2u
from angr.block import Block
from angr.analyses.analysis import Analysis
from angr.analyses import AnalysesHub
from angr.knowledge_plugins.functions import Function
from angr.codenode import BlockNode, HookNode
from angr.engines.light import SimEngineNostmtVEX, SimEngineLight, SpOffset, RegisterOffset
from angr.calling_conventions import SimStackArg, default_cc
from angr.analyses.propagator.vex_vars import VEXReg, VEXTmp
class SV:
"""
SizedValue: A faster implementation of claripy.ast.BV.
"""
__slots__ = ("bits", "value")
def __init__(self, value, bits):
self.value = value
self.bits = bits
def __eq__(self, other):
return isinstance(other, SV) and self.value == other.value and self.bits == other.bits
def __hash__(self):
return hash((self.value, self.bits))
class FCPState:
"""
The abstract state for FastConstantPropagation.
"""
__slots__ = (
"bp_value",
"callee_stored_regs",
"regs",
"simple_stack",
"sp_value",
"stack",
"tmps",
)
def __init__(self):
self.tmps = {}
self.simple_stack = {}
self.regs: dict[int, SV] = {}
self.stack: dict[int, SV] = {}
self.sp_value = 0
self.bp_value = 0
def register_read(self, offset, size_in_bytes: int) -> int | None:
if offset in self.regs:
v = self.regs[offset]
if v.bits == size_in_bytes * 8:
return v.value
return None
def register_written(self, offset: int, size_in_bytes: int, value: int | None):
if value is None:
to_remove = set()
for off, v in self.regs.items():
if (off <= offset < off + v.bits // 8) or (offset <= off < offset + size_in_bytes):
to_remove.add(off)
for off in to_remove:
del self.regs[off]
else:
self.regs[offset] = SV(value, size_in_bytes * 8)
def stack_read(self, offset: int, size_int_bytes: int) -> int | None:
if offset in self.stack:
v = self.stack[offset]
if v.bits == size_int_bytes * 8:
return v.value
return None
def stack_written(self, offset: int, size_int_bytes: int, value: int | None):
if value is None:
to_remove = set()
for off, v in self.stack.items():
if (off <= offset < off + v.bits // 8) or (offset <= off < offset + size_int_bytes):
to_remove.add(off)
for off in to_remove:
del self.stack[off]
else:
self.stack[offset] = SV(value, size_int_bytes * 8)
def copy(self, with_tmps: bool = False) -> FCPState:
new_state = FCPState()
new_state.stack = self.stack.copy()
new_state.regs = self.regs.copy()
new_state.sp_value = self.sp_value
new_state.bp_value = self.bp_value
new_state.simple_stack = self.simple_stack.copy()
if with_tmps:
new_state.tmps = self.tmps.copy()
return new_state
binop_handler = SimEngineNostmtVEX[FCPState, claripy.ast.BV, FCPState].binop_handler
class SimEngineFCPVEX(
SimEngineNostmtVEX[FCPState, SpOffset | RegisterOffset | int, None],
SimEngineLight[type[FCPState], SpOffset | RegisterOffset | int, Block, None],
):
"""
THe engine for FastConstantPropagation.
"""
def __init__(self, project, bp_as_gpr: bool, replacements: dict[int, dict], load_callback: Callable | None = None):
self.bp_as_gpr = bp_as_gpr
self.replacements = replacements
self._load_callback = load_callback
self.base_state = None
super().__init__(project)
def _allow_loading(self, addr: int, size: int) -> bool:
if self._load_callback is None:
return True
return self._load_callback(addr, size)
def _process_block_end(self, stmt_result: list, whitelist: set[int] | None) -> None:
if self.block.vex.jumpkind == "Ijk_Call":
self.state.register_written(self.arch.ret_offset, self.arch.bytes, None)
def _top(self, bits: int):
return None
def _is_top(self, expr: Any) -> bool:
raise NotImplementedError
def _handle_conversion(self, from_size: int, to_size: int, signed: bool, operand: pyvex.IRExpr) -> Any:
return None
def _handle_stmt_Put(self, stmt):
v = self._expr(stmt.data)
if stmt.offset == self.arch.sp_offset and isinstance(v, SpOffset):
self.state.sp_value = v.offset
elif stmt.offset == self.arch.bp_offset and not self.bp_as_gpr and isinstance(v, SpOffset):
self.state.bp_value = v.offset
elif isinstance(v, int):
size = stmt.data.result_size(self.tyenv) // self.arch.byte_width
codeloc = self._codeloc()
self.state.register_written(stmt.offset, size, v)
if stmt.offset != self.arch.ip_offset:
self.replacements[codeloc][VEXReg(stmt.offset, size)] = v
else:
size = stmt.data.result_size(self.tyenv) // self.arch.byte_width
self.state.register_written(stmt.offset, size, None)
def _handle_stmt_Store(self, stmt: pyvex.IRStmt.Store):
addr = self._expr(stmt.addr)
if isinstance(addr, SpOffset):
data = self._expr(stmt.data)
if isinstance(data, int):
self.state.stack_written(addr.offset, stmt.data.result_size(self.tyenv) // self.arch.byte_width, data)
else:
self.state.stack_written(addr.offset, stmt.data.result_size(self.tyenv) // self.arch.byte_width, None)
def _handle_stmt_WrTmp(self, stmt: pyvex.IRStmt.WrTmp):
if isinstance(stmt.data, pyvex.IRExpr.Binop) and not (
stmt.data.op.startswith("Iop_Add")
or stmt.data.op.startswith("Iop_Sub")
or stmt.data.op.startswith("Iop_And")
):
return
v = self._expr(stmt.data)
if v is not None:
self.state.tmps[stmt.tmp] = v
if isinstance(v, int):
codeloc = self._codeloc()
self.replacements[codeloc][VEXTmp(stmt.tmp)] = v
def _handle_expr_Const(self, expr: pyvex.IRExpr.Const):
return expr.con.value
def _handle_expr_GSPTR(self, expr):
return None
def _handle_expr_Get(self, expr) -> SpOffset | None:
if expr.offset == self.arch.sp_offset:
return SpOffset(self.arch.bits, self.state.sp_value, is_base=False)
if expr.offset == self.arch.bp_offset and not self.bp_as_gpr:
return SpOffset(self.arch.bits, self.state.bp_value, is_base=False)
size = expr.result_size(self.tyenv) // self.arch.byte_width
v = self.state.register_read(expr.offset, size)
if v is not None:
codeloc = self._codeloc()
self.replacements[codeloc][VEXReg(expr.offset, size)] = v
return v
def _handle_expr_GetI(self, expr):
return None
def _handle_expr_ITE(self, expr):
return None
def _handle_expr_Load(self, expr) -> int | SpOffset | None:
addr = self._expr(expr.addr)
if isinstance(addr, SpOffset):
return self.state.stack.get(addr.offset)
if isinstance(addr, int):
size = expr.result_size(self.tyenv) // self.arch.byte_width
if self._allow_loading(addr, size):
# Try loading from the state
if self.base_state is not None:
data = self.base_state.memory.load(addr, size, endness=expr.endness)
if not data.symbolic:
return data.args[0]
else:
try:
return self.project.loader.memory.unpack_word(addr, size=size, endness=expr.endness)
except KeyError:
pass
return None
def _handle_expr_RdTmp(self, expr):
return self.state.tmps.get(expr.tmp, None)
def _dummy_handler(self, expr): # pylint:disable=unused-argument,no-self-use
return None
_handle_expr_VECRET = _dummy_handler
_handle_expr_CCall = _dummy_handler
_handle_expr_Unop = _dummy_handler
_handle_expr_Triop = _dummy_handler
@binop_handler
def _handle_binop_Add(self, expr):
op0, op1 = self._expr(expr.args[0]), self._expr(expr.args[1])
if isinstance(op0, SpOffset) and isinstance(op1, int):
return SpOffset(op0.bits, s2u(op0.offset + op1, op0.bits), is_base=op0.is_base)
if isinstance(op1, SpOffset) and isinstance(op0, int):
return SpOffset(op1.bits, s2u(op1.offset + op0, op1.bits), is_base=op1.is_base)
if isinstance(op0, int) and isinstance(op1, int):
mask = (1 << expr.result_size(self.tyenv)) - 1
return (op0 + op1) & mask
return None
@binop_handler
def _handle_binop_Sub(self, expr):
op0, op1 = self._expr(expr.args[0]), self._expr(expr.args[1])
if isinstance(op0, SpOffset) and isinstance(op1, int):
return SpOffset(op0.bits, s2u(op0.offset - op1, op0.bits), is_base=op0.is_base)
if isinstance(op1, SpOffset) and isinstance(op0, int):
return SpOffset(op1.bits, s2u(op1.offset - op0, op1.bits), is_base=op1.is_base)
if isinstance(op0, int) and isinstance(op1, int):
mask = (1 << expr.result_size(self.tyenv)) - 1
return (op0 - op1) & mask
return None
@binop_handler
def _handle_binop_And(self, expr):
op0, op1 = self._expr(expr.args[0]), self._expr(expr.args[1])
if isinstance(op0, SpOffset):
return op0
if isinstance(op1, SpOffset):
return op1
if isinstance(op0, int) and isinstance(op1, int):
return op0 & op1
return None
[文档]
class FastConstantPropagation(Analysis):
"""
An extremely fast constant propagation analysis that finds function-wide constant values with potentially high
false negative rates.
"""
[文档]
def __init__(
self,
func: Function,
blocks: set[Block] | None = None,
vex_cross_insn_opt: bool = False,
load_callback: Callable | None = None,
):
self.function = func
self._blocks = blocks
self._vex_cross_insn_opt = vex_cross_insn_opt
self._load_callback = load_callback
self.replacements = {}
self._analyze()
def _analyze(self):
# traverse the function graph, collect registers and stack variables that are written to
func_graph = self.function.graph
func_graph_with_callees = self.function.transition_graph
startpoint = self.function.startpoint
bp_as_gpr = self.function.info.get("bp_as_gpr", False)
replacements = defaultdict(dict)
engine = SimEngineFCPVEX(self.project, bp_as_gpr, replacements, load_callback=self._load_callback)
init_state = FCPState()
if self.project.arch.call_pushes_ret:
init_state.sp_value = self.project.arch.bytes
init_state.bp_value = init_state.sp_value
sorted_nodes = reversed(list(networkx.dfs_postorder_nodes(func_graph, startpoint)))
block_addrs = None
if self._blocks:
block_addrs = {b.addr for b in self._blocks}
states: dict[BlockNode, FCPState] = {}
for node in sorted_nodes:
preds = func_graph.predecessors(node)
input_states = [states[pred] for pred in preds if pred in states]
state = init_state.copy() if not input_states else self._merge_states(input_states)
if self._blocks and node.addr not in block_addrs:
# skip this block
states[node] = state
continue
# process the block
if isinstance(node, BlockNode) and node.size == 0:
continue
if isinstance(node, HookNode):
# attempt to convert it into a function
if self.kb.functions.contains_addr(node.addr):
node = self.kb.functions.get_by_addr(node.addr)
else:
continue
if isinstance(node, Function):
if node.calling_convention is not None and node.prototype is not None:
# consume args and overwrite the return register
self._handle_function(state, node)
continue
block = self.project.factory.block(node.addr, size=node.size, cross_insn_opt=self._vex_cross_insn_opt)
engine.process(state, block=block)
# if the node ends with a function call, call _handle_function
succs = list(func_graph_with_callees.successors(node))
if any(isinstance(succ, (Function, HookNode)) for succ in succs):
callee = next(succ for succ in succs if isinstance(succ, (Function, HookNode)))
if isinstance(callee, HookNode):
# attempt to convert it into a function
if self.kb.functions.contains_addr(callee.addr):
callee = self.kb.functions.get_by_addr(callee.addr)
else:
callee = None
state = self._handle_function(state, callee)
states[node] = state
self.replacements = replacements
@staticmethod
def _merge_states(states: list[FCPState]) -> FCPState:
state = FCPState()
to_drop = set()
common_keys = set.intersection(*[set(s.regs) for s in states])
for s in states:
for offset, value in s.regs.items():
if offset in common_keys:
if offset in state.regs:
if state.regs[offset] != value:
to_drop.add(offset)
else:
state.regs[offset] = value
for offset in to_drop:
del state.regs[offset]
to_drop = set()
common_keys = set.intersection(*[set(s.stack) for s in states])
for s in states:
for offset, value in s.stack.items():
if offset in common_keys:
if offset in state.stack:
if state.stack[offset] != value:
to_drop.add(offset)
else:
state.stack[offset] = value
for offset in to_drop:
del state.stack[offset]
for s in states:
state.sp_value = max(state.sp_value, s.sp_value)
state.bp_value = max(state.bp_value, s.bp_value)
return state
def _handle_function(self, state: FCPState, func: Function | None) -> FCPState:
if func is None or func.calling_convention is None:
cc = default_cc(self.project.arch.name)
else:
cc = func.calling_convention
out_state = state.copy()
if func is not None and func.prototype is not None:
arg_locs = None
try:
arg_locs = cc.arg_locs(func.prototype)
except (TypeError, ValueError):
arg_locs = None
if None in arg_locs:
arg_locs = None
if arg_locs is not None:
for arg_loc in arg_locs:
for loc in arg_loc.get_footprint():
if isinstance(loc, SimStackArg):
sp_value = out_state.sp_value
if sp_value is not None:
out_state.stack_read(sp_value + loc.stack_offset, loc.size)
# clobber caller-saved regs
for reg_name in cc.CALLER_SAVED_REGS:
offset = self.project.arch.registers[reg_name][0]
out_state.register_written(offset, self.project.arch.registers[reg_name][1], None)
return out_state
AnalysesHub.register_default("FastConstantPropagation", FastConstantPropagation)