# pylint:disable=arguments-renamed,too-many-boolean-expressions,no-self-use,unused-argument
from __future__ import annotations
from typing import Any
from collections.abc import Callable
from collections import defaultdict
from archinfo import Endness
from ailment.expression import (
Const,
Register,
Expression,
StackBaseOffset,
Convert,
BinaryOp,
VirtualVariable,
Phi,
UnaryOp,
VirtualVariableCategory,
)
from ailment.statement import ConditionalJump, Jump, Assignment
import claripy
from angr.utils.bits import zeroextend_on_demand
from angr.engines.light import SimEngineNostmtAIL
from angr.storage.memory_mixins import (
SimpleInterfaceMixin,
DefaultFillerMixin,
PagedMemoryMixin,
UltraPagesMixin,
)
from angr.code_location import CodeLocation
from angr.errors import SimMemoryMissingError
from .optimization_pass import OptimizationPass, OptimizationPassStage
def _make_binop(compute: Callable[[claripy.ast.BV, claripy.ast.BV], claripy.ast.BV]):
    """
    Build a binary-operator handler method around ``compute``.

    The returned function evaluates both operands of the expression through ``self._expr`` and, when both
    evaluations produced a value, applies ``compute`` to them.

    :param compute: Function combining the two evaluated operands into the result value.
    :return:        A handler taking ``(self, expr)`` that returns the computed value, or None when either
                    operand evaluated to None or the computation hit a division by zero.
    """

    def inner(self, expr: BinaryOp) -> claripy.ast.BV | None:
        lhs = self._expr(expr.operands[0])
        rhs = self._expr(expr.operands[1])
        # identity checks only: comparing ASTs with == would build a new expression instead of a bool
        if lhs is None or rhs is None:
            return None
        try:
            result = compute(lhs, rhs)
        except (ZeroDivisionError, claripy.ClaripyZeroDivisionError):
            result = None
        return result

    return inner
class FasterMemory(
    SimpleInterfaceMixin,
    DefaultFillerMixin,
    UltraPagesMixin,
    PagedMemoryMixin,
):
    """
    A fast memory model used in InlinedStringTransformationState.

    The class adds no behavior of its own: it only stacks the four storage mixins listed in its bases.
    InlinedStringTransformationState creates two instances of it, one with memory_id "reg" and one with
    memory_id "mem".
    """
class InlinedStringTransformationState:
    """
    The abstract state used in InlinedStringTransformationAILEngine.

    It bundles two FasterMemory regions (one for registers, one for memory) with a dictionary that maps
    virtual-variable IDs to their current values.
    """

    def __init__(self, project):
        self.project = project
        self.arch = project.arch
        self.virtual_variables = {}

        self.registers = FasterMemory(memory_id="reg")
        self.memory = FasterMemory(memory_id="mem")
        # both regions get this object as their owning state
        for region in (self.registers, self.memory):
            region.set_state(self)

    def _get_weakref(self):
        # the state hands out a plain reference to itself
        return self

    def reg_store(self, reg: Register, value: claripy.ast.BV) -> None:
        """
        Write ``value`` at the register offset of ``reg``; the size in bytes is derived from the value itself.
        """
        nbytes = value.size() // self.arch.byte_width
        endness = str(self.arch.register_endness)
        self.registers.store(reg.reg_offset, value, size=nbytes, endness=endness)

    def reg_load(self, reg: Register) -> claripy.ast.BV | None:
        """
        Read ``reg.size`` bytes at the register offset of ``reg``.

        :return: The loaded value, or None if the load raised SimMemoryMissingError.
        """
        try:
            loaded = self.registers.load(
                reg.reg_offset, size=reg.size, endness=self.arch.register_endness, fill_missing=False
            )
        except SimMemoryMissingError:
            return None
        return loaded

    def mem_store(self, addr: int, value: claripy.ast.Bits, endness: str) -> None:
        """
        Write ``value`` to memory at ``addr``; the size in bytes is derived from the value itself.
        """
        nbytes = value.size() // self.arch.byte_width
        self.memory.store(addr, value, size=nbytes, endness=endness)

    def mem_load(self, addr: int, size: int, endness) -> claripy.ast.BV | None:
        """
        Read ``size`` bytes of memory at ``addr``.

        :return: The loaded value, or None if the load raised SimMemoryMissingError.
        """
        try:
            loaded = self.memory.load(addr, size=size, endness=str(endness), fill_missing=False)
        except SimMemoryMissingError:
            return None
        return loaded

    def vvar_store(self, vvar: VirtualVariable, value: claripy.ast.Bits | None) -> None:
        """
        Record ``value`` as the current value of the virtual variable, keyed by its ``varid``.
        """
        self.virtual_variables[vvar.varid] = value

    def vvar_load(self, vvar: VirtualVariable) -> claripy.ast.BV | None:
        """
        Look up the recorded value of the virtual variable; None if nothing was recorded for its ``varid``.
        """
        return self.virtual_variables.get(vvar.varid)
class InlinedStringTransformationAILEngine(
    SimEngineNostmtAIL[InlinedStringTransformationState, claripy.ast.BV | None, None, None],
):
    """
    A simple AIL execution engine.

    The constructor itself runs the execution: starting at ``start``, it processes one node per step and follows
    the program counter set by the jump handlers until it reaches ``end``, loses track of the next address, or
    uses up ``step_limit`` steps. Every byte-level stack load and store seen on the way is appended to
    ``stack_accesses``; ``finished`` tells whether ``end`` was reached.
    """

    def __init__(self, project, nodes: dict[int, Any], start: int, end: int, step_limit: int):
        """
        :param project:     The project; passed to the base engine and used to create the execution state.
        :param nodes:       Mapping from address to the node (block) to process at that address.
        :param start:       Address at which execution starts.
        :param end:         Address that marks a successful end of execution.
        :param step_limit:  Maximum number of nodes to process.
        """
        super().__init__(project)

        self.nodes: dict[int, Any] = nodes
        self.start: int = start
        self.end: int = end
        self.step_limit: int = step_limit

        # concrete address standing in for the stack base, and the mask used to wrap computed addresses
        self.STACK_BASE = 0x7FFF_FFF0 if self.arch.bits == 32 else 0x7FFF_FFFF_F000
        self.MASK = 0xFFFF_FFFF if self.arch.bits == 32 else 0xFFFF_FFFF_FFFF_FFFF

        state = InlinedStringTransformationState(project)
        # address -> list of (access kind, code location, byte value)
        self.stack_accesses: defaultdict[int, list[tuple[str, CodeLocation, claripy.ast.Bits]]] = defaultdict(list)
        self.finished: bool = False

        self.last_pc = None
        self.pc = self.start
        for _ in range(self.step_limit):
            if self.pc not in self.nodes:
                # jumped to a node that we do not know about
                break
            block = self.nodes[self.pc]
            self._process(state, block=block, whitelist=None)
            if self.pc is None:
                # not sure where to jump...
                break
            if self.pc == self.end:
                # we reach the end of execution!
                self.finished = True
                break

    def _top(self, bits):
        # raised explicitly: a bare `assert False` is removed under -O and would silently return None
        raise AssertionError("Should not be reachable")

    def _is_top(self, expr):
        # raised explicitly: a bare `assert False` is removed under -O and would silently return None
        raise AssertionError("Should not be reachable")

    def _process_block_end(self, block, stmt_data, whitelist):
        pass

    #
    # Helpers
    #

    def _process_address(self, addr: Expression) -> tuple[int, str] | None:
        """
        Resolve an address expression to a concrete address.

        Supported forms are a constant, a stack-base offset, and ``stack-base-offset +/- concrete value``.

        :param addr: The address expression.
        :return:     ``(address, "mem")`` for constants, ``(address, "stack")`` for stack addresses, or None if
                     the address cannot be resolved.
        """
        if isinstance(addr, Const):
            assert isinstance(addr.value, int)
            return addr.value, "mem"
        if isinstance(addr, StackBaseOffset):
            return (addr.offset + self.STACK_BASE) & self.MASK, "stack"
        if (
            isinstance(addr, BinaryOp)
            # only additive operators describe an address relative to the stack base; treating any other
            # operator as an addition would yield a wrong address
            and addr.op in ("Add", "Sub")
            and isinstance(addr.operands[0], StackBaseOffset)
        ):
            base_and_type = self._process_address(addr.operands[0])
            if base_and_type is not None:
                base = base_and_type[0]
                delta = self._expr(addr.operands[1])
                if isinstance(delta, claripy.ast.Bits) and delta.concrete:
                    if addr.op == "Sub":
                        return (base - delta.concrete_value) & self.MASK, "stack"
                    return (base + delta.concrete_value) & self.MASK, "stack"
        return None

    def _log_stack_bytes(self, kind: str, addr: int, value: claripy.ast.Bits, nbytes: int) -> None:
        """
        Append one ``stack_accesses`` record per byte of ``value``, for addresses ``addr`` .. ``addr + nbytes - 1``.

        :param kind:   Access kind stored in each record ("store" or "load").
        :param addr:   Address of the first byte.
        :param value:  The value that was stored or loaded.
        :param nbytes: Number of bytes to record.
        """
        # with little-endian memory the bytes are picked from the value in reverse index order
        # NOTE(review): the byte order always follows arch.memory_endness, even when the access itself carried
        # its own endness (Store/Load) - confirm these never differ
        little_endian = self.arch.memory_endness == Endness.LE
        for i in range(nbytes):
            byte_idx = nbytes - i - 1 if little_endian else i
            self.stack_accesses[addr + i].append((kind, self._codeloc(), value.get_byte(byte_idx)))

    def _compare_concrete(self, expr, predicate: Callable[[int, int], bool]) -> claripy.ast.BV | None:
        """
        Evaluate both operands of a comparison and apply ``predicate`` to their concrete values.

        :param expr:      The comparison expression.
        :param predicate: Function deciding the comparison on the two concrete values.
        :return:          A one-bit value (1 if the predicate holds, 0 otherwise), or None if either operand is
                          unavailable or not concrete.
        """
        # NOTE(review): the concrete values are compared as-is; expr.signed is not consulted - confirm that
        # signed comparisons do not need different treatment
        op0, op1 = self._expr(expr.operands[0]), self._expr(expr.operands[1])
        if isinstance(op0, claripy.ast.Bits) and isinstance(op1, claripy.ast.Bits) and op0.concrete and op1.concrete:
            return claripy.BVV(1 if predicate(op0.concrete_value, op1.concrete_value) else 0, 1)
        return None

    #
    # Statement handlers
    #

    def _handle_stmt_Assignment(self, stmt):
        """
        Execute an assignment to a virtual variable.

        Register-backed variables are recorded in the state's virtual-variable table. Stack-backed variables are
        written to memory at their stack address and every written byte is logged. Other destinations are ignored.
        """
        dst = stmt.dst
        if not isinstance(dst, VirtualVariable):
            return
        if dst.was_reg:
            val = self._expr(stmt.src)
            if isinstance(val, claripy.ast.Bits):
                self.state.vvar_store(dst, val)
        elif dst.was_stack:
            addr = (dst.stack_offset + self.STACK_BASE) & self.MASK
            val = self._expr(stmt.src)
            if isinstance(val, claripy.ast.BV):
                self.state.mem_store(addr, val, self.arch.memory_endness)
                # log it
                self._log_stack_bytes("store", addr, val, val.size() // self.arch.byte_width)

    def _handle_stmt_Store(self, stmt):
        """
        Execute a store whose address can be resolved; stores to stack addresses are logged byte by byte.
        """
        addr_and_type = self._process_address(stmt.addr)
        if addr_and_type is None:
            return
        addr, addr_type = addr_and_type
        val = self._expr(stmt.data)
        if isinstance(val, claripy.ast.BV):
            self.state.mem_store(addr, val, stmt.endness)
            # log it
            if addr_type == "stack":
                self._log_stack_bytes("store", addr, val, val.size() // self.arch.byte_width)

    def _handle_stmt_Jump(self, stmt):
        """
        Follow a jump to a constant target; any other target leaves the program counter unknown (None).
        """
        self.last_pc = self.pc
        if isinstance(stmt.target, Const):
            self.pc = stmt.target.value
        else:
            self.pc = None

    def _handle_stmt_ConditionalJump(self, stmt):
        """
        Follow a conditional jump when both targets are constants and the condition evaluates to a concrete 0 or 1.

        In every other case the program counter is left unknown (None).
        """
        self.last_pc = self.pc
        self.pc = None
        if not (isinstance(stmt.true_target, Const) and isinstance(stmt.false_target, Const)):
            return
        cond = self._expr(stmt.condition)
        # only a concrete condition can select a branch
        if not isinstance(cond, claripy.ast.Bits) or not cond.concrete:
            return
        cond_value = cond.concrete_value
        if cond_value == 1:
            self.pc = stmt.true_target.value
        elif cond_value == 0:
            self.pc = stmt.false_target.value

    def _handle_stmt_Call(self, stmt):
        pass

    #
    # Expression handlers
    #

    def _handle_expr_Const(self, expr):
        """
        Turn an integer constant into a bitvector value of the constant's width; other constants yield None.
        """
        if isinstance(expr.value, int):
            return claripy.BVV(expr.value, expr.bits)
        return None

    def _handle_expr_Load(self, expr):
        """
        Execute a load from a resolvable address; loads from stack addresses that return a value are logged.

        :return: The loaded value, or None if the address cannot be resolved or nothing is stored there.
        """
        addr_and_type = self._process_address(expr.addr)
        if addr_and_type is None:
            return None
        addr, addr_type = addr_and_type
        v = self.state.mem_load(addr, expr.size, expr.endness)
        # log it
        if addr_type == "stack" and isinstance(v, claripy.ast.BV):
            self._log_stack_bytes("load", addr, v, expr.size)
        return v

    def _handle_expr_Register(self, expr: Register):
        return self.state.reg_load(expr)

    def _handle_expr_VirtualVariable(self, expr: VirtualVariable):
        """
        Read a virtual variable: stack-backed ones from memory (logged byte by byte), register-backed ones from
        the state's virtual-variable table. Other variables yield None.
        """
        if expr.was_stack:
            addr = (expr.stack_offset + self.STACK_BASE) & self.MASK
            v = self.state.mem_load(addr, expr.size, self.arch.memory_endness)
            if v is not None:
                # log it
                self._log_stack_bytes("load", addr, v, expr.size)
            return v
        if expr.was_reg:
            return self.state.vvar_load(expr)
        return None

    def _handle_expr_Phi(self, expr: Phi):
        """
        Pick the phi source whose first source component equals the previously executed address (``last_pc``).

        :return: The value of the selected virtual variable, or None if no source matches.
        """
        for src, vvar in expr.src_and_vvars:
            if src[0] == self.last_pc and vvar is not None:
                return self.state.vvar_load(vvar)
        return None

    def _handle_expr_Convert(self, expr: Convert):
        """
        Widen (zero- or sign-extend, depending on ``expr.is_signed``) or truncate the operand to ``expr.to_bits``.
        """
        v = self._expr(expr.operand)
        if not isinstance(v, claripy.ast.Bits):
            return None
        if expr.to_bits > expr.from_bits:
            extra_bits = expr.to_bits - expr.from_bits
            if not expr.is_signed:
                return claripy.ZeroExt(extra_bits, v)
            return claripy.SignExt(extra_bits, v)
        if expr.to_bits < expr.from_bits:
            return claripy.Extract(expr.to_bits - 1, 0, v)
        return v

    def _handle_expr_Call(self, expr):
        pass

    def _handle_expr_BasePointerOffset(self, expr):
        return None

    def _handle_expr_DirtyExpression(self, expr):
        return None

    def _handle_expr_ITE(self, expr):
        return None

    def _handle_expr_MultiStatementExpression(self, expr):
        return None

    def _handle_expr_Reinterpret(self, expr):
        return None

    def _handle_expr_StackBaseOffset(self, expr):
        return None

    def _handle_expr_Tmp(self, expr):
        try:
            return self.tmps[expr.tmp_idx]
        except KeyError:
            return None

    def _handle_expr_VEXCCallExpression(self, expr):
        return None

    #
    # Unary operation handlers
    #

    def _handle_unop_Neg(self, expr: UnaryOp):
        v = self._expr(expr.operand)
        if isinstance(v, claripy.ast.Bits):
            return -v
        return None

    def _handle_unop_Not(self, expr: UnaryOp):
        v = self._expr(expr.operand)
        if isinstance(v, claripy.ast.Bits):
            return ~v
        return None

    def _handle_unop_BitwiseNeg(self, expr: UnaryOp):
        v = self._expr(expr.operand)
        if isinstance(v, claripy.ast.Bits):
            return ~v
        return None

    def _handle_unop_Default(self, expr: UnaryOp):
        # unsupported unary operations produce no value
        return None

    _handle_unop_Clz = _handle_unop_Default
    _handle_unop_Ctz = _handle_unop_Default
    _handle_unop_Dereference = _handle_unop_Default
    _handle_unop_Reference = _handle_unop_Default
    _handle_unop_GetMSBs = _handle_unop_Default
    _handle_unop_unpack = _handle_unop_Default
    _handle_unop_Sqrt = _handle_unop_Default
    _handle_unop_RSqrtEst = _handle_unop_Default

    #
    # Binary operation handlers
    #

    def _handle_binop_CmpEQ(self, expr):
        return self._compare_concrete(expr, lambda a, b: a == b)

    def _handle_binop_CmpNE(self, expr):
        return self._compare_concrete(expr, lambda a, b: a != b)

    def _handle_binop_CmpLT(self, expr):
        return self._compare_concrete(expr, lambda a, b: a < b)

    def _handle_binop_CmpLE(self, expr):
        return self._compare_concrete(expr, lambda a, b: a <= b)

    def _handle_binop_CmpGT(self, expr):
        return self._compare_concrete(expr, lambda a, b: a > b)

    def _handle_binop_CmpGE(self, expr):
        return self._compare_concrete(expr, lambda a, b: a >= b)

    def _handle_binop_Default(self, expr):
        # unsupported binary operations: the operands are still evaluated, but no value is produced
        self._expr(expr.operands[0])
        self._expr(expr.operands[1])

    _handle_binop_Add = _make_binop(lambda a, b: a + b)
    _handle_binop_And = _make_binop(lambda a, b: a & b)
    _handle_binop_Concat = _make_binop(lambda a, b: a.concat(b))
    _handle_binop_Div = _make_binop(lambda a, b: a // b)
    _handle_binop_LogicalAnd = _make_binop(lambda a, b: a & b)
    _handle_binop_LogicalOr = _make_binop(lambda a, b: a | b)
    _handle_binop_Mod = _make_binop(lambda a, b: a % b)
    _handle_binop_Mul = _make_binop(lambda a, b: a * b)
    _handle_binop_Or = _make_binop(lambda a, b: a | b)
    _handle_binop_Rol = _make_binop(lambda a, b: claripy.RotateLeft(a, zeroextend_on_demand(a, b)))
    _handle_binop_Ror = _make_binop(lambda a, b: claripy.RotateRight(a, zeroextend_on_demand(a, b)))
    _handle_binop_Sar = _make_binop(lambda a, b: a >> zeroextend_on_demand(a, b))
    _handle_binop_Shl = _make_binop(lambda a, b: a << zeroextend_on_demand(a, b))
    _handle_binop_Shr = _make_binop(lambda a, b: a.LShR(zeroextend_on_demand(a, b)))
    _handle_binop_Sub = _make_binop(lambda a, b: a - b)
    _handle_binop_Xor = _make_binop(lambda a, b: a ^ b)

    def _handle_binop_Mull(self, expr):
        """
        Widening multiplication: both operands are extended by the bit width of the first operand (sign- or
        zero-extended, depending on ``expr.signed``) before being multiplied.
        """
        a, b = self._expr(expr.operands[0]), self._expr(expr.operands[1])
        if a is None or b is None:
            return None
        xt = a.size()
        if expr.signed:
            return a.sign_extend(xt) * b.sign_extend(xt)
        return a.zero_extend(xt) * b.zero_extend(xt)

    _handle_binop_AddF = _handle_binop_Default
    _handle_binop_AddV = _handle_binop_Default
    _handle_binop_Carry = _handle_binop_Default
    _handle_binop_CmpF = _handle_binop_Default
    _handle_binop_DivF = _handle_binop_Default
    _handle_binop_DivV = _handle_binop_Default
    _handle_binop_InterleaveLOV = _handle_binop_Default
    _handle_binop_InterleaveHIV = _handle_binop_Default
    _handle_binop_CasCmpEQ = _handle_binop_Default
    _handle_binop_CasCmpNE = _handle_binop_Default
    _handle_binop_ExpCmpNE = _handle_binop_Default
    _handle_binop_SarNV = _handle_binop_Default
    _handle_binop_ShrNV = _handle_binop_Default
    _handle_binop_ShlNV = _handle_binop_Default
    _handle_binop_CmpEQV = _handle_binop_Default
    _handle_binop_CmpNEV = _handle_binop_Default
    _handle_binop_CmpGEV = _handle_binop_Default
    _handle_binop_CmpGTV = _handle_binop_Default
    _handle_binop_CmpLEV = _handle_binop_Default
    _handle_binop_CmpLTV = _handle_binop_Default
    _handle_binop_MulF = _handle_binop_Default
    _handle_binop_MulV = _handle_binop_Default
    _handle_binop_MulHiV = _handle_binop_Default
    _handle_binop_SBorrow = _handle_binop_Default
    _handle_binop_SCarry = _handle_binop_Default
    _handle_binop_SubF = _handle_binop_Default
    _handle_binop_SubV = _handle_binop_Default
    _handle_binop_MinV = _handle_binop_Default
    _handle_binop_MaxV = _handle_binop_Default
    _handle_binop_QAddV = _handle_binop_Default
    _handle_binop_QNarrowBinV = _handle_binop_Default
    _handle_binop_PermV = _handle_binop_Default
    _handle_binop_Set = _handle_binop_Default
class InlineStringTransformationDescriptor:
    """
    Describes an instance of inline string transformation.

    A plain record: the four constructor arguments are kept, unmodified, as attributes of the same names.
    """

    def __init__(self, store_block, loop_body, stack_accesses, beginning_stack_offset):
        (
            self.store_block,
            self.loop_body,
            self.stack_accesses,
            self.beginning_stack_offset,
        ) = (store_block, loop_body, stack_accesses, beginning_stack_offset)