from __future__ import annotations
from typing import Any
import pyvex
import claripy
from angr.utils.bits import s2u, u2s
from angr.block import Block
from angr.analyses.analysis import Analysis
from angr.analyses import AnalysesHub
from angr.knowledge_plugins.functions import Function
from angr.codenode import BlockNode, HookNode
from angr.engines.light import SimEngineNostmtVEX, SimEngineLight, SpOffset, RegisterOffset
from angr.calling_conventions import SimRegArg, SimStackArg, default_cc
from angr.sim_type import SimTypeBottom
from .utils import is_sane_register_variable
class FactCollectorState:
    """
    Abstract state carried along one path while facts about a function are collected.

    Attributes (all plain Python containers/ints):

    - ``tmps``: values associated with temporaries, keyed by temporary id.
    - ``simple_stack``: values remembered for stack slots, keyed by stack offset.
    - ``callee_stored_regs``: register offset -> stack offset the register was stored to.
    - ``reg_reads``: register offset -> largest size (in bytes) read from that offset before it was written.
    - ``reg_writes``: every register *byte* offset that has been written.
    - ``stack_reads``: stack offset -> largest size (in bytes) read from that offset before it was written.
    - ``stack_writes``: every stack *byte* offset that has been written.
    - ``sp_value`` / ``bp_value``: current tracked offsets of the stack pointer and the base pointer.
    """

    __slots__ = (
        "bp_value",
        "callee_stored_regs",
        "reg_reads",
        "reg_writes",
        "simple_stack",
        "sp_value",
        "stack_reads",
        "stack_writes",
        "tmps",
    )

    def __init__(self):
        # pointer tracking
        self.sp_value = 0
        self.bp_value = 0
        # read/write footprints
        self.reg_reads = {}
        self.reg_writes: set[int] = set()
        self.stack_reads = {}
        self.stack_writes: set[int] = set()
        # value tracking
        self.tmps = {}
        self.simple_stack = {}
        self.callee_stored_regs: dict[int, int] = {}  # reg offset -> stack offset

    def register_read(self, offset: int, size_in_bytes: int):
        """
        Record a register read, unless the byte at ``offset`` has already been written.

        When the same offset is read several times, the largest size is kept.
        """
        if offset in self.reg_writes:
            return
        known_size = self.reg_reads.get(offset)
        if known_size is None:
            self.reg_reads[offset] = size_in_bytes
        else:
            self.reg_reads[offset] = max(known_size, size_in_bytes)

    def register_written(self, offset: int, size_in_bytes: int):
        """
        Mark every byte in ``[offset, offset + size_in_bytes)`` of the register file as written.
        """
        self.reg_writes.update(range(offset, offset + size_in_bytes))

    def stack_read(self, offset: int, size_in_bytes: int):
        """
        Record a stack read, unless the byte at ``offset`` has already been written.

        When the same offset is read several times, the largest size is kept.
        """
        if offset in self.stack_writes:
            return
        known_size = self.stack_reads.get(offset)
        if known_size is None:
            self.stack_reads[offset] = size_in_bytes
        else:
            self.stack_reads[offset] = max(known_size, size_in_bytes)

    def stack_written(self, offset: int, size_int_bytes: int):
        """
        Mark every byte in ``[offset, offset + size_int_bytes)`` of the stack as written.
        """
        self.stack_writes.update(range(offset, offset + size_int_bytes))

    def copy(self, with_tmps: bool = False) -> "FactCollectorState":
        """
        Return a new state whose containers are shallow copies of this state's containers.

        :param with_tmps:   Also copy ``tmps``. When False, the new state starts with empty ``tmps``.
        """
        duplicate = FactCollectorState()
        duplicate.sp_value = self.sp_value
        duplicate.bp_value = self.bp_value
        duplicate.reg_reads = dict(self.reg_reads)
        duplicate.reg_writes = set(self.reg_writes)
        duplicate.stack_reads = dict(self.stack_reads)
        duplicate.stack_writes = set(self.stack_writes)
        duplicate.callee_stored_regs = dict(self.callee_stored_regs)
        duplicate.simple_stack = dict(self.simple_stack)
        if with_tmps:
            duplicate.tmps = dict(self.tmps)
        return duplicate
# Decorator taken from SimEngineNostmtVEX; it is applied to the `_handle_binop_*` methods of the engine class
# defined right after this line.
binop_handler = SimEngineNostmtVEX[FactCollectorState, claripy.ast.BV, FactCollectorState].binop_handler
class SimEngineFactCollectorVEX(
    SimEngineNostmtVEX[FactCollectorState, SpOffset | RegisterOffset | int, None],
    SimEngineLight[type[FactCollectorState], SpOffset | RegisterOffset | int, Block, None],
):
    """
    The engine for FactCollector.

    Expressions evaluate to one of:

    - ``SpOffset``: a value relative to the tracked stack pointer (or base pointer),
    - ``RegisterOffset``: the value read from a register,
    - ``int``: a constant,
    - ``None``: anything this engine does not model.

    While processing a block, the engine records register/stack reads and writes on the current
    ``FactCollectorState``.
    """

    def __init__(self, project, bp_as_gpr: bool):
        """
        :param project:     The project, forwarded to the base engine.
        :param bp_as_gpr:   When True, reads of the base pointer register are treated as ordinary register reads
                            instead of stack-pointer-relative values.
        """
        self.bp_as_gpr = bp_as_gpr
        super().__init__(project)

    def _process_block_end(self, stmt_result: list, whitelist: set[int] | None) -> None:
        # a block that ends with a call is treated as overwriting the return-value register
        if self.block.vex.jumpkind == "Ijk_Call":
            self.state.register_written(self.arch.ret_offset, self.arch.bytes)

    def _top(self, bits: int):
        # unknown values are represented by None
        return None

    def _is_top(self, expr: Any) -> bool:
        raise NotImplementedError

    def _handle_conversion(self, from_size: int, to_size: int, signed: bool, operand: pyvex.IRExpr) -> Any:
        # conversions are not modeled
        return None

    def _handle_stmt_Put(self, stmt):
        v = self._expr(stmt.data)
        if stmt.offset == self.arch.sp_offset and isinstance(v, SpOffset):
            # stack pointer update: track the new offset; not recorded as a register write
            self.state.sp_value = v.offset
        elif stmt.offset == self.arch.bp_offset and isinstance(v, SpOffset):
            # base pointer receives a stack-relative value: track it; not recorded as a register write
            self.state.bp_value = v.offset
        else:
            self.state.register_written(stmt.offset, stmt.data.result_size(self.tyenv) // self.arch.byte_width)

    def _handle_stmt_Store(self, stmt: pyvex.IRStmt.Store):
        addr = self._expr(stmt.addr)
        # only stores to stack-pointer-relative addresses are tracked
        if isinstance(addr, SpOffset):
            self.state.stack_written(addr.offset, stmt.data.result_size(self.tyenv) // self.arch.byte_width)
            data = self._expr(stmt.data)
            if isinstance(data, RegisterOffset) and not isinstance(data, SpOffset):
                # push reg; we record the stored register as well as the stack slot offset
                self.state.callee_stored_regs[data.reg] = u2s(addr.offset, self.arch.bits)
            if isinstance(data, SpOffset):
                # remember stack-relative values stored on the stack so that later loads can recover them
                self.state.simple_stack[addr.offset] = data

    def _handle_stmt_WrTmp(self, stmt: pyvex.IRStmt.WrTmp):
        v = self._expr(stmt.data)
        # unknown values (None) are not stored; reading such a tmp later yields None again
        if v is not None:
            self.state.tmps[stmt.tmp] = v

    def _handle_expr_Const(self, expr: pyvex.IRExpr.Const):
        return expr.con.value

    def _handle_expr_GSPTR(self, expr):
        return None

    def _handle_expr_Get(self, expr) -> SpOffset | RegisterOffset:
        # reads of sp (and of bp, unless it is used as a general-purpose register) yield the tracked offsets and
        # are not recorded as register reads
        if expr.offset == self.arch.sp_offset:
            return SpOffset(self.arch.bits, self.state.sp_value, is_base=False)
        if expr.offset == self.arch.bp_offset and not self.bp_as_gpr:
            return SpOffset(self.arch.bits, self.state.bp_value, is_base=False)
        bits = expr.result_size(self.tyenv)
        self.state.register_read(expr.offset, bits // self.arch.byte_width)
        return RegisterOffset(bits, expr.offset, 0)

    def _handle_expr_GetI(self, expr):
        return None

    def _handle_expr_ITE(self, expr):
        return None

    def _handle_expr_Load(self, expr):
        addr = self._expr(expr.addr)
        if isinstance(addr, SpOffset):
            self.state.stack_read(addr.offset, expr.result_size(self.tyenv) // self.arch.byte_width)
            # yields the stack-relative value previously stored in this slot, or None if there is none
            return self.state.simple_stack.get(addr.offset)
        return None

    def _handle_expr_RdTmp(self, expr):
        return self.state.tmps.get(expr.tmp, None)

    def _handle_expr_VECRET(self, expr):
        return None

    @binop_handler
    def _handle_binop_Add(self, expr):
        op0, op1 = self._expr(expr.args[0]), self._expr(expr.args[1])
        # addition is commutative: sp + const and const + sp are both stack-relative
        if isinstance(op0, SpOffset) and isinstance(op1, int):
            return SpOffset(op0.bits, s2u(op0.offset + op1, op0.bits), is_base=op0.is_base)
        if isinstance(op1, SpOffset) and isinstance(op0, int):
            return SpOffset(op1.bits, s2u(op1.offset + op0, op1.bits), is_base=op1.is_base)
        return None

    @binop_handler
    def _handle_binop_Sub(self, expr):
        op0, op1 = self._expr(expr.args[0]), self._expr(expr.args[1])
        if isinstance(op0, SpOffset) and isinstance(op1, int):
            return SpOffset(op0.bits, s2u(op0.offset - op1, op0.bits), is_base=op0.is_base)
        # Subtraction is not commutative: `const - (sp + k)` is not relative to the stack pointer, so it cannot be
        # expressed as an SpOffset and is treated as unknown.
        return None

    @binop_handler
    def _handle_binop_And(self, expr):
        op0, op1 = self._expr(expr.args[0]), self._expr(expr.args[1])
        # masking a stack-relative value returns that value unchanged
        # NOTE(review): presumably this approximates stack-alignment masks -- confirm
        if isinstance(op0, SpOffset):
            return op0
        if isinstance(op1, SpOffset):
            return op1
        return None
# [docs]
class FactCollector(Analysis):
    """
    An extremely fast analysis that extracts necessary facts of a function for CallingConventionAnalysis to make
    decision on the calling convention and prototype of a function.

    After construction, two results are available:

    - ``input_args``: a list of ``SimRegArg``/``SimStackArg`` locations that the function reads before writing.
    - ``retval_size``: the largest observed size (in bytes) of a write to the return-value register near the
      function's endpoints, or None if nothing was observed.
    """

    def __init__(self, func: Function, max_depth: int = 5):
        """
        :param func:        The function to analyze.
        :param max_depth:   Maximum depth of the graph traversals.
        """
        self.function = func
        self._max_depth = max_depth

        self.input_args: list[SimRegArg | SimStackArg] | None = None
        self.retval_size: int | None = None

        self._analyze()

    def _analyze(self):
        # breadth-first search using function graph, collect registers and stack variables that are written to as well
        # as read from, until max_depth is reached
        end_states = self._analyze_startpoint()
        self._analyze_endpoints_for_retval_size()
        callee_restored_regs = self._analyze_endpoints_for_restored_regs()
        self._determine_input_args(end_states, callee_restored_regs)

    def _analyze_startpoint(self):
        """
        Walk the function's transition graph forwards from its start point and collect one state per explored path.

        :return:    A list of FactCollectorState objects, one for each path that ended.
        """
        func_graph = self.function.transition_graph
        startpoint = self.function.startpoint
        bp_as_gpr = self.function.info.get("bp_as_gpr", False)
        engine = SimEngineFactCollectorVEX(self.project, bp_as_gpr)
        init_state = FactCollectorState()
        if self.project.arch.call_pushes_ret:
            # the return address occupies one word on the stack at function entry
            init_state.sp_value = self.project.arch.bytes
        init_state.bp_value = init_state.sp_value

        traversed = set()
        # each queue item: (depth, state, node to process, node to continue at after a call returns)
        queue: list[tuple[int, FactCollectorState, BlockNode | HookNode | Function, BlockNode | HookNode | None]] = [
            (0, init_state, startpoint, None)
        ]
        end_states: list[FactCollectorState] = []
        while queue:
            depth, state, node, retnode = queue.pop(0)
            traversed.add(node)

            if depth > self._max_depth:
                end_states.append(state)
                break

            if isinstance(node, BlockNode) and node.size == 0:
                continue
            if isinstance(node, HookNode):
                # attempt to convert it into a function
                if self.kb.functions.contains_addr(node.addr):
                    node = self.kb.functions.get_by_addr(node.addr)
                else:
                    continue
            if isinstance(node, Function):
                if node.calling_convention is not None and node.prototype is not None:
                    # consume args and overwrite the return register
                    self._handle_function(state, node)
                if node.returning is False or retnode is None:
                    # the function call does not return
                    end_states.append(state)
                else:
                    # enqueue the retnode, but we don't increment the depth
                    new_state = state.copy()
                    if self.project.arch.call_pushes_ret:
                        # the callee popped the return address on return
                        new_state.sp_value += self.project.arch.bytes
                    queue.append((depth, new_state, retnode, None))
                continue

            block = self.project.factory.block(node.addr, size=node.size)
            engine.process(state, block=block)

            successor_added = False
            call_succ, ret_succ = None, None
            for _, succ, data in func_graph.out_edges(node, data=True):
                edge_type = data.get("type")
                outside = data.get("outside", False)
                if succ not in traversed and depth + 1 <= self._max_depth:
                    if edge_type == "fake_return":
                        ret_succ = succ
                    elif edge_type == "transition" and not outside:
                        successor_added = True
                        queue.append((depth + 1, state.copy(), succ, None))
                    elif edge_type == "call" or (edge_type == "transition" and outside):
                        # a call or a tail-call
                        if not isinstance(succ, Function):
                            if self.kb.functions.contains_addr(succ.addr):
                                succ = self.kb.functions.get_by_addr(succ.addr)
                            else:
                                # not sure who we are calling
                                continue
                        call_succ = succ
            if call_succ is not None:
                # the callee is enqueued together with the fall-through node (if any) to resume at afterwards
                successor_added = True
                queue.append((depth + 1, state.copy(), call_succ, ret_succ))

            if not successor_added:
                end_states.append(state)

        return end_states

    def _handle_function(self, state: FactCollectorState, func: Function) -> None:
        """
        Apply the effects of calling `func` to `state`: its argument locations are recorded as reads, and its
        caller-saved registers are recorded as written.

        Nothing is applied if the argument locations cannot be determined.
        """
        try:
            arg_locs = func.calling_convention.arg_locs(func.prototype)
        except (TypeError, ValueError):
            return

        if None in arg_locs:
            return

        for arg_loc in arg_locs:
            for loc in arg_loc.get_footprint():
                if isinstance(loc, SimRegArg):
                    state.register_read(self.project.arch.registers[loc.reg_name][0] + loc.reg_offset, loc.size)
                elif isinstance(loc, SimStackArg):
                    sp_value = state.sp_value
                    if sp_value is not None:
                        state.stack_read(sp_value + loc.stack_offset, loc.size)

        # clobber caller-saved regs
        for reg_name in func.calling_convention.CALLER_SAVED_REGS:
            offset = self.project.arch.registers[reg_name][0]
            state.register_written(offset, self.project.arch.registers[reg_name][1])

    def _analyze_endpoints_for_retval_size(self):
        """
        Analyze all endpoints to determine the return value size.

        The result is stored in `self.retval_size`. It is left untouched if the default calling convention does not
        return values in a register.
        """
        func_graph = self.function.transition_graph
        cc_cls = default_cc(
            self.project.arch.name, platform=self.project.simos.name if self.project.simos is not None else None
        )
        cc = cc_cls(self.project.arch)
        if isinstance(cc.RETURN_VAL, SimRegArg):
            retreg_offset = cc.RETURN_VAL.check_offset(self.project.arch)
        else:
            return

        retval_sizes = []
        for endpoint in self.function.endpoints:
            traversed = set()
            queue: list[tuple[int, BlockNode | HookNode]] = [(0, endpoint)]
            while queue:
                depth, node = queue.pop(0)
                traversed.add(node)

                # NOTE(review): this cap is hard-coded to 3 while predecessors are enqueued up to self._max_depth
                if depth > 3:
                    break

                if isinstance(node, BlockNode) and node.size == 0:
                    continue
                if isinstance(node, HookNode):
                    # attempt to convert it into a function
                    if self.kb.functions.contains_addr(node.addr):
                        node = self.kb.functions.get_by_addr(node.addr)
                    else:
                        continue
                if isinstance(node, Function):
                    if (
                        node.calling_convention is not None
                        and node.prototype is not None
                        and node.prototype.returnty is not None
                        and not isinstance(node.prototype.returnty, SimTypeBottom)
                    ):
                        # assume the function overwrites the return variable
                        retval_size = (
                            node.prototype.returnty.with_arch(self.project.arch).size // self.project.arch.byte_width
                        )
                        retval_sizes.append(retval_size)
                    continue

                block = self.project.factory.block(node.addr, size=node.size)
                # scan the block statements backwards to find writes to the return value register
                # NOTE(review): there is no early exit, so when a block writes the register several times the size
                # of the earliest write in the block is the one that is kept -- confirm this is intended
                retval_size = None
                for stmt in reversed(block.vex.statements):
                    if isinstance(stmt, pyvex.IRStmt.Put):
                        size = stmt.data.result_size(block.vex.tyenv) // self.project.arch.byte_width
                        if stmt.offset == retreg_offset:
                            retval_size = max(size, 1)

                if retval_size is not None:
                    retval_sizes.append(retval_size)
                    continue

                # no write found in this block; keep looking in its predecessors
                for pred, _, data in func_graph.in_edges(node, data=True):
                    edge_type = data.get("type")
                    if pred not in traversed and depth + 1 <= self._max_depth:
                        if edge_type == "fake_return":
                            continue
                        if edge_type in {"transition", "call"}:
                            queue.append((depth + 1, pred))

        self.retval_size = max(retval_sizes) if retval_sizes else None

    def _analyze_endpoints_for_restored_regs(self):
        """
        Analyze all endpoints to determine the restored registers.

        :return:    A set of offsets of registers that are assigned a full-width value loaded from the stack.
        """
        func_graph = self.function.transition_graph
        callee_restored_regs = set()

        for endpoint in self.function.endpoints:
            traversed = set()
            queue: list[tuple[int, BlockNode | HookNode]] = [(0, endpoint)]
            while queue:
                depth, node = queue.pop(0)
                traversed.add(node)

                if depth > 3:
                    break

                if isinstance(node, BlockNode) and node.size == 0:
                    continue
                if isinstance(node, (HookNode, Function)):
                    continue

                block = self.project.factory.block(node.addr, size=node.size)
                # scan the block statements in order to find all statements that restore registers from the stack;
                # each tmp is classified as "sp" (stack address), "stack_value" (loaded from the stack) or "const"
                tmps = {}
                for stmt in block.vex.statements:
                    if isinstance(stmt, pyvex.IRStmt.WrTmp):
                        if isinstance(stmt.data, pyvex.IRExpr.Get) and stmt.data.offset in {
                            self.project.arch.bp_offset,
                            self.project.arch.sp_offset,
                        }:
                            tmps[stmt.tmp] = "sp"
                        elif (
                            isinstance(stmt.data, pyvex.IRExpr.Load)
                            and isinstance(stmt.data.addr, pyvex.IRExpr.RdTmp)
                            and tmps.get(stmt.data.addr.tmp) == "sp"
                        ):
                            tmps[stmt.tmp] = "stack_value"
                        elif isinstance(stmt.data, pyvex.IRExpr.Const):
                            tmps[stmt.tmp] = "const"
                        elif isinstance(stmt.data, pyvex.IRExpr.Binop) and (  # noqa:SIM102
                            stmt.data.op.startswith("Iop_Add") or stmt.data.op.startswith("Iop_Sub")
                        ):
                            # adding to or subtracting from a stack address yields another stack address
                            if (
                                isinstance(stmt.data.args[0], pyvex.IRExpr.RdTmp)
                                and tmps.get(stmt.data.args[0].tmp) == "sp"
                            ) or (
                                isinstance(stmt.data.args[1], pyvex.IRExpr.RdTmp)
                                and tmps.get(stmt.data.args[1].tmp) == "sp"
                            ):
                                tmps[stmt.tmp] = "sp"
                    if isinstance(stmt, pyvex.IRStmt.Put):
                        size = stmt.data.result_size(block.vex.tyenv) // self.project.arch.byte_width
                        # is the data loaded from the stack?
                        if (
                            size == self.project.arch.bytes
                            and isinstance(stmt.data, pyvex.IRExpr.RdTmp)
                            and tmps.get(stmt.data.tmp) == "stack_value"
                        ):
                            callee_restored_regs.add(stmt.offset)

                for pred, _, data in func_graph.in_edges(node, data=True):
                    edge_type = data.get("type")
                    if pred not in traversed and depth + 1 <= self._max_depth and edge_type == "transition":
                        queue.append((depth + 1, pred))

        return callee_restored_regs

    def _determine_input_args(self, end_states: list[FactCollectorState], callee_restored_regs: set[int]) -> None:
        """
        Populate `self.input_args` from the collected states.

        :param end_states:              The states returned by `_analyze_startpoint`.
        :param callee_restored_regs:    Offsets of registers that are restored from the stack near the endpoints.
        """
        self.input_args = []
        reg_offset_created = set()
        callee_saved_regs = set()
        callee_saved_reg_stack_offsets = set()

        # determine callee-saved registers: stored to the stack on the way in and restored on the way out
        for state in end_states:
            for reg_offset, stack_offset in state.callee_stored_regs.items():
                if reg_offset in callee_restored_regs:
                    callee_saved_regs.add(reg_offset)
                    callee_saved_reg_stack_offsets.add(stack_offset)

        for state in end_states:
            for offset, size in state.reg_reads.items():
                if (
                    offset in reg_offset_created
                    or offset == self.project.arch.bp_offset
                    or not is_sane_register_variable(self.project.arch, offset, size)
                    or offset in callee_saved_regs
                ):
                    continue
                reg_offset_created.add(offset)
                if self.project.arch.name in {"AMD64", "X86"} and size < self.project.arch.bytes:
                    # use complete registers on AMD64 and X86
                    reg_name = self.project.arch.translate_register_name(offset, size=self.project.arch.bytes)
                    arg = SimRegArg(reg_name, self.project.arch.bytes)
                else:
                    reg_name = self.project.arch.translate_register_name(offset, size=size)
                    arg = SimRegArg(reg_name, size)
                self.input_args.append(arg)

        stack_offset_created = set()
        ret_addr_offset = 0 if not self.project.arch.call_pushes_ret else self.project.arch.bytes
        for state in end_states:
            for offset, size in state.stack_reads.items():
                offset = u2s(offset, self.project.arch.bits)
                # only slots strictly above the return address offset are considered stack arguments
                if offset - ret_addr_offset > 0:
                    if offset in stack_offset_created or offset in callee_saved_reg_stack_offsets:
                        continue
                    stack_offset_created.add(offset)
                    arg = SimStackArg(offset - ret_addr_offset, size)
                    self.input_args.append(arg)
# Register FactCollector with the analyses hub under the name "FunctionFactCollector".
AnalysesHub.register_default("FunctionFactCollector", FactCollector)