angr.analyses.decompiler.condition_processor 源代码

from __future__ import annotations
from collections import defaultdict, OrderedDict
from typing import Any
from collections.abc import Callable
from collections.abc import Generator
import operator
import logging

import ailment
import claripy
import networkx
from unique_log_filter import UniqueLogFilter


from angr.utils.graph import GraphUtils
from angr.utils.lazy_import import lazy_import
from angr.utils import is_pyinstaller
from angr.utils.graph import dominates, inverted_idoms
from angr.block import Block, BlockNode
from angr.errors import AngrRuntimeError
from .peephole_optimizations import InvertNegatedLogicalConjunctionsAndDisjunctions
from .structuring.structurer_nodes import (
    MultiNode,
    EmptyBlockNotice,
    SequenceNode,
    CodeNode,
    SwitchCaseNode,
    BreakNode,
    ConditionalBreakNode,
    LoopNode,
    ConditionNode,
    ContinueNode,
    CascadingConditionNode,
    IncompleteSwitchCaseNode,
)
from .graph_region import GraphRegion
from .utils import first_nonlabel_nonphi_statement, peephole_optimize_expr

if is_pyinstaller():
    # PyInstaller is not happy with lazy import
    import sympy
else:
    sympy = lazy_import("sympy")


l = logging.getLogger(__name__)
l.addFilter(UniqueLogFilter())


_UNIFIABLE_COMPARISONS = {
    "__ne__",
    "__gt__",
    "__ge__",
    "UGT",
    "UGE",
    "SGT",
    "SGE",
}


_INVERSE_OPERATIONS = {
    "__eq__": "__ne__",
    "__ne__": "__eq__",
    "__gt__": "__le__",
    "__lt__": "__ge__",
    "__ge__": "__lt__",
    "__le__": "__gt__",
    "ULT": "UGE",
    "UGE": "ULT",
    "UGT": "ULE",
    "ULE": "UGT",
    "SLT": "SGE",
    "SGE": "SLT",
    "SLE": "SGT",
    "SGT": "SLE",
}


#
# Util methods and mapping used during AIL AST to claripy AST conversion
#


def _op_with_unified_size(op, conv: Callable, operand0, operand1, ins_addr: int):
    # ensure operand1 is of the same size as operand0
    if isinstance(operand1, ailment.Expr.Const):
        # amazing - we do the easy thing here
        return op(conv(operand0, nobool=True, ins_addr=ins_addr), operand1.value)
    if operand1.bits == operand0.bits:
        return op(conv(operand0, nobool=True, ins_addr=ins_addr), conv(operand1, ins_addr=ins_addr))
    # extension is required
    assert operand1.bits < operand0.bits
    operand1 = ailment.Expr.Convert(None, operand1.bits, operand0.bits, False, operand1)
    return op(conv(operand0, nobool=True, ins_addr=ins_addr), conv(operand1, nobool=True, ins_addr=ins_addr))


def _dummy_bvs(condition, condition_mapping, name_suffix=""):
    var = claripy.BVS(f"ailexpr_{condition!r}{name_suffix}", condition.bits, explicit_name=True)
    condition_mapping[var.args[0]] = condition
    return var


def _dummy_bools(condition, condition_mapping, name_suffix=""):
    var = claripy.BoolS(f"ailexpr_{condition!r}{name_suffix}", explicit_name=True)
    condition_mapping[var.args[0]] = condition
    return var


_ail2claripy_op_mapping = {
    "LogicalAnd": lambda expr, conv, _, ia: claripy.And(
        conv(expr.operands[0], ins_addr=ia), conv(expr.operands[1], ins_addr=ia)
    ),
    "LogicalOr": lambda expr, conv, _, ia: claripy.Or(
        conv(expr.operands[0], ins_addr=ia), conv(expr.operands[1], ins_addr=ia)
    ),
    "CmpEQ": lambda expr, conv, _, ia: conv(expr.operands[0], nobool=True, ins_addr=ia)
    == conv(expr.operands[1], nobool=True, ins_addr=ia),
    "CmpNE": lambda expr, conv, _, ia: conv(expr.operands[0], nobool=True, ins_addr=ia)
    != conv(expr.operands[1], nobool=True, ins_addr=ia),
    "CmpLE": lambda expr, conv, _, ia: conv(expr.operands[0], nobool=True, ins_addr=ia)
    <= conv(expr.operands[1], nobool=True, ins_addr=ia),
    "CmpLE (signed)": lambda expr, conv, _, ia: claripy.SLE(
        conv(expr.operands[0], nobool=True, ins_addr=ia), conv(expr.operands[1], nobool=True, ins_addr=ia)
    ),
    "CmpLT": lambda expr, conv, _, ia: conv(expr.operands[0], nobool=True, ins_addr=ia)
    < conv(expr.operands[1], nobool=True, ins_addr=ia),
    "CmpLT (signed)": lambda expr, conv, _, ia: claripy.SLT(
        conv(expr.operands[0], nobool=True, ins_addr=ia), conv(expr.operands[1], nobool=True, ins_addr=ia)
    ),
    "CmpGE": lambda expr, conv, _, ia: conv(expr.operands[0], nobool=True, ins_addr=ia)
    >= conv(expr.operands[1], nobool=True, ins_addr=ia),
    "CmpGE (signed)": lambda expr, conv, _, ia: claripy.SGE(
        conv(expr.operands[0], nobool=True, ins_addr=ia), conv(expr.operands[1], nobool=True, ins_addr=ia)
    ),
    "CmpGT": lambda expr, conv, _, ia: conv(expr.operands[0], nobool=True, ins_addr=ia)
    > conv(expr.operands[1], nobool=True, ins_addr=ia),
    "CmpGT (signed)": lambda expr, conv, _, ia: claripy.SGT(
        conv(expr.operands[0], nobool=True, ins_addr=ia), conv(expr.operands[1], nobool=True, ins_addr=ia)
    ),
    "CasCmpEQ": lambda expr, conv, _, ia: conv(expr.operands[0], nobool=True, ins_addr=ia)
    == conv(expr.operands[1], nobool=True, ins_addr=ia),
    "CasCmpNE": lambda expr, conv, _, ia: conv(expr.operands[0], nobool=True, ins_addr=ia)
    != conv(expr.operands[1], nobool=True, ins_addr=ia),
    "CasCmpLE": lambda expr, conv, _, ia: conv(expr.operands[0], nobool=True, ins_addr=ia)
    <= conv(expr.operands[1], nobool=True, ins_addr=ia),
    "CasCmpLE (signed)": lambda expr, conv, _, ia: claripy.SLE(
        conv(expr.operands[0], nobool=True, ins_addr=ia), conv(expr.operands[1], nobool=True, ins_addr=ia)
    ),
    "CasCmpLT": lambda expr, conv, _, ia: conv(expr.operands[0], nobool=True, ins_addr=ia)
    < conv(expr.operands[1], nobool=True, ins_addr=ia),
    "CasCmpLT (signed)": lambda expr, conv, _, ia: claripy.SLT(
        conv(expr.operands[0], nobool=True, ins_addr=ia), conv(expr.operands[1], nobool=True, ins_addr=ia)
    ),
    "CasCmpGE": lambda expr, conv, _, ia: conv(expr.operands[0], nobool=True, ins_addr=ia)
    >= conv(expr.operands[1], nobool=True, ins_addr=ia),
    "CasCmpGE (signed)": lambda expr, conv, _, ia: claripy.SGE(
        conv(expr.operands[0], nobool=True, ins_addr=ia), conv(expr.operands[1], nobool=True, ins_addr=ia)
    ),
    "CasCmpGT": lambda expr, conv, _, ia: conv(expr.operands[0], nobool=True, ins_addr=ia)
    > conv(expr.operands[1], nobool=True, ins_addr=ia),
    "CasCmpGT (signed)": lambda expr, conv, _, ia: claripy.SGT(
        conv(expr.operands[0], nobool=True, ins_addr=ia), conv(expr.operands[1], nobool=True, ins_addr=ia)
    ),
    "Add": lambda expr, conv, _, ia: conv(expr.operands[0], nobool=True, ins_addr=ia)
    + conv(expr.operands[1], nobool=True, ins_addr=ia),
    "Sub": lambda expr, conv, _, ia: conv(expr.operands[0], nobool=True, ins_addr=ia)
    - conv(expr.operands[1], nobool=True, ins_addr=ia),
    "Mul": lambda expr, conv, _, ia: conv(expr.operands[0], nobool=True, ins_addr=ia)
    * conv(expr.operands[1], nobool=True, ins_addr=ia),
    "Div": lambda expr, conv, _, ia: conv(expr.operands[0], nobool=True, ins_addr=ia)
    / conv(expr.operands[1], nobool=True, ins_addr=ia),
    "Mod": lambda expr, conv, _, ia: conv(expr.operands[0], nobool=True, ins_addr=ia)
    % conv(expr.operands[1], nobool=True, ins_addr=ia),
    "Not": lambda expr, conv, _, ia: claripy.Not(conv(expr.operand, ins_addr=ia)),
    "Neg": lambda expr, conv, _, ia: -conv(expr.operand, ins_addr=ia),
    "BitwiseNeg": lambda expr, conv, _, ia: ~conv(expr.operand, ins_addr=ia),
    "Xor": lambda expr, conv, _, ia: conv(expr.operands[0], nobool=True, ins_addr=ia)
    ^ conv(expr.operands[1], nobool=True, ins_addr=ia),
    "And": lambda expr, conv, _, ia: conv(expr.operands[0], nobool=True, ins_addr=ia)
    & conv(expr.operands[1], nobool=True, ins_addr=ia),
    "Or": lambda expr, conv, _, ia: conv(expr.operands[0], nobool=True, ins_addr=ia)
    | conv(expr.operands[1], nobool=True, ins_addr=ia),
    "Shr": lambda expr, conv, _, ia: _op_with_unified_size(claripy.LShR, conv, expr.operands[0], expr.operands[1], ia),
    "Shl": lambda expr, conv, _, ia: _op_with_unified_size(
        operator.lshift, conv, expr.operands[0], expr.operands[1], ia
    ),
    "Sar": lambda expr, conv, _, ia: _op_with_unified_size(
        operator.rshift, conv, expr.operands[0], expr.operands[1], ia
    ),
    "Concat": lambda expr, conv, _, ia: claripy.Concat(*[conv(operand, ins_addr=ia) for operand in expr.operands]),
    # There are no corresponding claripy operations for the following operations
    "CmpF": lambda expr, _, m, *args: _dummy_bvs(expr, m),
    "Mull": lambda expr, _, m, *args: _dummy_bvs(expr, m),
    "Mull (signed)": lambda expr, _, m, *args: _dummy_bvs(expr, m),
    "Reinterpret": lambda expr, _, m, *args: _dummy_bvs(expr, m),
    "Rol": lambda expr, _, m, *args: _dummy_bvs(expr, m),
    "Ror": lambda expr, _, m, *args: _dummy_bvs(expr, m),
    "LogicalXor": lambda expr, _, m, *args: _dummy_bvs(expr, m),
    "Carry": lambda expr, _, m, *args: _dummy_bvs(expr, m),
    "SCarry": lambda expr, _, m, *args: _dummy_bvs(expr, m),
    "SBorrow": lambda expr, _, m, *args: _dummy_bvs(expr, m),
    "ExpCmpNE": lambda expr, _, m, *args: _dummy_bools(expr, m),
    "CmpORD": lambda expr, _, m, *args: _dummy_bvs(expr, m),  # in case CmpORDRewriter fails
    "CmpEQV": lambda expr, _, m, *args: _dummy_bvs(expr, m),
    "GetMSBs": lambda expr, _, m, *args: _dummy_bvs(expr, m),
    "ShlNV": lambda expr, _, m, *args: _dummy_bvs(expr, m),
    "ShrNV": lambda expr, _, m, *args: _dummy_bvs(expr, m),
    "InterleaveLOV": lambda expr, _, m, *args: _dummy_bvs(expr, m),
    "InterleaveHIV": lambda expr, _, m, *args: _dummy_bvs(expr, m),
    # catch-all
    "_DUMMY_": lambda expr, _, m, *args: _dummy_bvs(expr, m),
}

#
# The ConditionProcessor class
#


[文档] class ConditionProcessor: """ Convert between claripy AST and AIL expressions. Also calculates reaching conditions of all nodes on a graph. """
[文档] def __init__(self, arch, condition_mapping=None): self.arch = arch self._condition_mapping: dict[str, Any] = {} if condition_mapping is None else condition_mapping self.jump_table_conds: dict[int, set] = defaultdict(set) self.edge_conditions = {} self.reaching_conditions = {} self.guarding_conditions = {} self._ast2annotations = {} self._peephole_expr_optimizations = [ cls(None, None, None) for cls in [InvertNegatedLogicalConjunctionsAndDisjunctions] ]
[文档] def clear(self): self._condition_mapping = {} self.jump_table_conds = defaultdict(set) self.reaching_conditions = {} self.guarding_conditions = {} self._ast2annotations = {}
[文档] def recover_edge_condition(self, graph: networkx.DiGraph, src, dst): edge = src, dst edge_data = graph.get_edge_data(*edge) edge_type = edge_data.get("type", "transition") if edge_data is not None else "transition" try: predicate = self._extract_predicate(src, dst, edge_type) except EmptyBlockNotice: # catch empty block notice - although this should not really happen predicate = claripy.true() return predicate
[文档] def recover_edge_conditions(self, region, graph=None) -> dict: edge_conditions = {} # traverse the graph to recover the condition for each edge graph = graph or region.graph for src in graph.nodes(): nodes = list(graph[src]) if len(nodes) >= 1: for dst in nodes: predicate = self.recover_edge_condition(graph, src, dst) edge_conditions[(src, dst)] = predicate self.edge_conditions = edge_conditions
[文档] def recover_reaching_conditions( self, region, graph=None, with_successors=False, case_entry_to_switch_head: dict[int, int] | None = None, simplify_conditions: bool = True, ): """ Recover the reaching conditions for each block in an acyclic graph. Note that we assume the graph that's passed in is acyclic. """ def _strictly_postdominates(inv_idoms, node_a, node_b): """ Does node A strictly post-dominate node B on the graph? """ return dominates(inv_idoms, node_a, node_b) self.recover_edge_conditions(region, graph=graph) edge_conditions = self.edge_conditions if graph: _g = graph head = next(node for node in graph.nodes if graph.in_degree(node) == 0) else: if with_successors and region.graph_with_successors is not None: _g = region.graph_with_successors else: _g = region.graph head = region.head # special handling for jump table entries - do not allow crossing between cases if case_entry_to_switch_head: _g = self._remove_crossing_edges_between_cases(_g, case_entry_to_switch_head) inverted_graph, idoms = inverted_idoms(_g) reaching_conditions = {} # recover the reaching condition for each node sorted_nodes = GraphUtils.quasi_topological_sort_nodes(_g) terminating_nodes = [] for node in sorted_nodes: # create special conditions for all nodes that are jump table entries if case_entry_to_switch_head and node.addr in case_entry_to_switch_head: jump_target_var = self.create_jump_target_var(case_entry_to_switch_head[node.addr]) cond = jump_target_var == claripy.BVV(node.addr, self.arch.bits) reaching_conditions[node] = cond self.jump_table_conds[case_entry_to_switch_head[node.addr]].add(cond) continue preds = _g.predecessors(node) reaching_condition = None out_degree = _g.out_degree(node) if out_degree == 0: terminating_nodes.append(node) if node is head: # the head is always reachable reaching_condition = claripy.true() elif idoms is not None and _strictly_postdominates(idoms, node, head): # the node that post dominates the head is always reachable reaching_conditions[node] = claripy.true() else: for pred in preds: edge = (pred, node) pred_condition = reaching_conditions.get(pred, claripy.true()) edge_condition = edge_conditions.get(edge, claripy.true()) if reaching_condition is None: reaching_condition = claripy.And(pred_condition, edge_condition) else: reaching_condition = claripy.Or(claripy.And(pred_condition, edge_condition), reaching_condition) if reaching_condition is not None: reaching_conditions[node] = ( self.simplify_condition(reaching_condition) if simplify_conditions else reaching_condition ) # My hypothesis: for nodes where two paths come together *and* those that cannot be further structured into # another if-else construct (we take the short-cut by testing if the operator is an "Or" after running our # condition simplifiers previously), we are better off using their "guarding conditions" instead of their # reaching conditions for if-else. see my super long chatlog with rhelmot on 5/14/2021. guarding_conditions = {} for the_node in sorted_nodes: preds = list(_g.predecessors(the_node)) if len(preds) != 2: continue # generate a graph slice that goes from the region head to this node slice_nodes = list(networkx.dfs_tree(inverted_graph, the_node)) subgraph = networkx.subgraph(_g, slice_nodes) # figure out which paths cause the divergence from this node nodes_do_not_reach_the_node = set() for node_ in subgraph: if node_ is the_node: continue for succ in _g.successors(node_): if not networkx.has_path(_g, succ, the_node): nodes_do_not_reach_the_node.add(succ) diverging_conditions = [] for node_ in nodes_do_not_reach_the_node: preds_ = list(_g.predecessors(node_)) for pred_ in preds_: if pred_ in nodes_do_not_reach_the_node: continue # this predecessor is the diverging node! edge_ = pred_, node_ edge_condition = edge_conditions.get(edge_, None) if edge_condition is not None: diverging_conditions.append(edge_condition) if diverging_conditions: # the negation of the union of diverging conditions is the guarding condition for this node cond = claripy.Or(*map(claripy.Not, diverging_conditions)) # pylint:disable=bad-builtin guarding_conditions[the_node] = cond self.reaching_conditions = reaching_conditions self.guarding_conditions = guarding_conditions
[文档] def remove_claripy_bool_asts(self, node, memo=None): # Convert claripy Bool ASTs to AIL expressions if memo is None: memo = {} if isinstance(node, SequenceNode): new_nodes = [] for n in node.nodes: new_node = self.remove_claripy_bool_asts(n, memo=memo) new_nodes.append(new_node) return SequenceNode(node.addr, new_nodes) if isinstance(node, MultiNode): new_nodes = [] for n in node.nodes: new_node = self.remove_claripy_bool_asts(n, memo=memo) new_nodes.append(new_node) return MultiNode(nodes=new_nodes) if isinstance(node, CodeNode): return CodeNode( self.remove_claripy_bool_asts(node.node, memo=memo), ( None if node.reaching_condition is None else self.convert_claripy_bool_ast(node.reaching_condition, memo=memo) ), ) if isinstance(node, ConditionalBreakNode): return ConditionalBreakNode( node.addr, self.convert_claripy_bool_ast(node.condition, memo=memo), node.target, ) if isinstance(node, ConditionNode): return ConditionNode( node.addr, ( None if node.reaching_condition is None else self.convert_claripy_bool_ast(node.reaching_condition, memo=memo) ), self.convert_claripy_bool_ast(node.condition, memo=memo), self.remove_claripy_bool_asts(node.true_node, memo=memo), self.remove_claripy_bool_asts(node.false_node, memo=memo), ) if isinstance(node, CascadingConditionNode): cond_and_nodes = [] for cond, child_node in node.condition_and_nodes: cond_and_nodes.append( ( self.convert_claripy_bool_ast(cond, memo=memo), self.remove_claripy_bool_asts(child_node, memo=memo), ) ) else_node = None if node.else_node is None else self.remove_claripy_bool_asts(node.else_node, memo=memo) return CascadingConditionNode( node.addr, cond_and_nodes, else_node=else_node, ) if isinstance(node, LoopNode): result = node.copy() result.condition = ( self.convert_claripy_bool_ast(node.condition, memo=memo) if node.condition is not None else None ) result.sequence_node = self.remove_claripy_bool_asts(node.sequence_node, memo=memo) return result if isinstance(node, SwitchCaseNode): return SwitchCaseNode( self.convert_claripy_bool_ast(node.switch_expr, memo=memo), OrderedDict( (idx, self.remove_claripy_bool_asts(case_node, memo=memo)) for idx, case_node in node.cases.items() ), self.remove_claripy_bool_asts(node.default_node, memo=memo), addr=node.addr, ) if isinstance(node, IncompleteSwitchCaseNode): return IncompleteSwitchCaseNode( node.addr, self.remove_claripy_bool_asts(node.head, memo=memo), [self.remove_claripy_bool_asts(case, memo=memo) for case in node.cases], ) return node
[文档] @classmethod def get_last_statement(cls, block): """ This is the buggy version of get_last_statements, because, you know, there can always be more than one last statement due to the existence of branching statements (like, If-then-else). All methods using get_last_statement() should switch to get_last_statements() and properly handle multiple last statements. """ if type(block) is SequenceNode: if block.nodes: return cls.get_last_statement(block.nodes[-1]) raise EmptyBlockNotice if type(block) is CodeNode: return cls.get_last_statement(block.node) if type(block) is ailment.Block: if not block.statements: raise EmptyBlockNotice return block.statements[-1] if type(block) is Block: raise NotImplementedError if type(block) is BlockNode: raise NotImplementedError if type(block) is MultiNode: # get the last node for the_block in reversed(block.nodes): try: return cls.get_last_statement(the_block) except EmptyBlockNotice: continue raise EmptyBlockNotice if type(block) is LoopNode: return cls.get_last_statement(block.sequence_node) if type(block) is ConditionalBreakNode: return None if type(block) is ConditionNode: s = None if block.true_node: try: s = cls.get_last_statement(block.true_node) except EmptyBlockNotice: s = None if s is None and block.false_node: s = cls.get_last_statement(block.false_node) return s if type(block) is CascadingConditionNode: s = None if block.else_node is not None: s = cls.get_last_statement(block.else_node) else: for _, node in reversed(block.condition_and_nodes): s = cls.get_last_statement(node) if s is not None: break return s if type(block) is BreakNode: return None if type(block) is ContinueNode: return None if type(block) is SwitchCaseNode: return None if type(block) is IncompleteSwitchCaseNode: return None if type(block) is GraphRegion: # normally this should not happen. however, we have test cases that trigger this case. return None raise NotImplementedError
[文档] @classmethod def get_last_statements(cls, block) -> list[ailment.Stmt.Statement | None]: if type(block) is SequenceNode: for last_node in reversed(block.nodes): try: return cls.get_last_statements(last_node) except EmptyBlockNotice: # the node is empty. try the next one continue raise EmptyBlockNotice if type(block) is CodeNode: return cls.get_last_statements(block.node) if type(block) is ailment.Block: if not block.statements: raise EmptyBlockNotice return [block.statements[-1]] if type(block) is Block: raise NotImplementedError if type(block) is BlockNode: raise NotImplementedError if type(block) is MultiNode: # get the last node for the_block in reversed(block.nodes): try: return cls.get_last_statements(the_block) except EmptyBlockNotice: continue raise EmptyBlockNotice if type(block) is LoopNode: if block.sequence_node is None: raise EmptyBlockNotice return cls.get_last_statements(block.sequence_node) if type(block) is ConditionalBreakNode: return [block] if type(block) is ConditionNode: s = [] if block.true_node: try: last_stmts = cls.get_last_statements(block.true_node) s.extend(last_stmts) except EmptyBlockNotice: pass else: s.append(None) if block.false_node: last_stmts = cls.get_last_statements(block.false_node) s.extend(last_stmts) else: s.append(None) return s if type(block) is CascadingConditionNode: s = [] if block.else_node is not None: try: last_stmts = cls.get_last_statements(block.else_node) s.extend(last_stmts) except EmptyBlockNotice: pass else: s.append(None) for _, node in block.condition_and_nodes: last_stmts = cls.get_last_statements(node) s.extend(last_stmts) return s if type(block) is BreakNode: return [block] if type(block) is ContinueNode: return [block] if type(block) is SwitchCaseNode: s = [] for case in block.cases.values(): s.extend(cls.get_last_statements(case)) if block.default_node is not None: s.extend(cls.get_last_statements(block.default_node)) else: s.append(None) return s if type(block) is IncompleteSwitchCaseNode: s = [] for case in block.cases: s.extend(cls.get_last_statements(case)) return s if type(block) is GraphRegion: # normally this should not happen. however, we have test cases that trigger this case. return [] raise NotImplementedError
# # Path predicate # EXC_COUNTER = 1000 def _extract_predicate(self, src_block, dst_block, edge_type) -> claripy.ast.Bool: if edge_type == "exception": # TODO: THIS IS ABSOLUTELY A HACK. AT THIS MOMENT YOU SHOULD NOT ATTEMPT TO MAKE SENSE OF EXCEPTION EDGES. self.EXC_COUNTER += 1 return self.claripy_ast_from_ail_condition( ailment.Expr.BinaryOp( None, "CmpEQ", ( ailment.Expr.Register(0x400000 + self.EXC_COUNTER, None, self.EXC_COUNTER, 64), ailment.Expr.Const(None, None, self.EXC_COUNTER, 64), ), False, ), ins_addr=dst_block.addr, ) if type(src_block) is ConditionalBreakNode: # at this point ConditionalBreakNode stores a claripy AST bool_var = src_block.condition if src_block.target == dst_block.addr: return bool_var return claripy.Not(bool_var) if type(src_block) is GraphRegion: return claripy.true() # sometimes the last statement is the conditional jump. sometimes it's the first statement of the block if ( isinstance(src_block, ailment.Block) and src_block.statements and isinstance(first_nonlabel_nonphi_statement(src_block), ailment.Stmt.ConditionalJump) ): last_stmt = first_nonlabel_nonphi_statement(src_block) else: last_stmt = self.get_last_statement(src_block) if last_stmt is None: return claripy.true() if type(last_stmt) is ailment.Stmt.Jump: if isinstance(last_stmt.target, ailment.Expr.Const): return claripy.true() # indirect jump target_ast = self.claripy_ast_from_ail_condition(last_stmt.target, ins_addr=last_stmt.ins_addr) return target_ast == dst_block.addr if type(last_stmt) is ailment.Stmt.ConditionalJump: bool_var = self.claripy_ast_from_ail_condition(last_stmt.condition, ins_addr=last_stmt.ins_addr) if isinstance(last_stmt.true_target, ailment.Expr.Const) and last_stmt.true_target.value == dst_block.addr: return bool_var return claripy.Not(bool_var) return claripy.true() # # Expression conversion # def _convert_extract(self, hi, lo, expr, tags, memo=None): # ailment does not support Extract. We translate Extract to Convert and shift. if lo == 0: return ailment.Expr.Convert( None, expr.size(), hi + 1, False, self.convert_claripy_bool_ast(expr, memo=memo), **tags, ) raise NotImplementedError("This case will be implemented once encountered.")
[文档] def convert_claripy_bool_ast(self, cond, memo=None): """ Convert recovered reaching conditions from claripy ASTs to ailment Expressions :return: None """ if not isinstance(cond, claripy.ast.Base): return cond if memo is None: memo = {} if cond._hash in memo: return memo[cond._hash] r = self.convert_claripy_bool_ast_core(cond, memo) optimized_r = peephole_optimize_expr(r, self._peephole_expr_optimizations) r = r if optimized_r is None else optimized_r memo[cond._hash] = r return r
[文档] def convert_claripy_bool_ast_core(self, cond, memo): if isinstance(cond, ailment.Expr.Expression): return cond if cond.op in {"BoolS", "BoolV"} and claripy.is_true(cond): return ailment.Expr.Const(None, None, True, 1) if cond in self._condition_mapping: return self._condition_mapping[cond] if cond.op in {"BVS", "BoolS"} and cond.args[0] in self._condition_mapping: return self._condition_mapping[cond.args[0]] def _binary_op_reduce(op, args, tags, signed=False): r = None for arg in args: if r is None: r = self.convert_claripy_bool_ast(arg, memo=memo) else: r = ailment.Expr.BinaryOp( None, op, (r, self.convert_claripy_bool_ast(arg, memo=memo)), signed, **tags ) return r def _unary_op_reduce(op, arg, tags): r = self.convert_claripy_bool_ast(arg, memo=memo) # TODO: Keep track of tags return ailment.Expr.UnaryOp(None, op, r, **tags) _mapping = { "Not": lambda cond_, tags: _unary_op_reduce("Not", cond_.args[0], tags), "__neg__": lambda cond_, tags: _unary_op_reduce("Not", cond_.args[0], tags), "__invert__": lambda cond_, tags: _unary_op_reduce("BitwiseNeg", cond_.args[0], tags), "And": lambda cond_, tags: _binary_op_reduce("LogicalAnd", cond_.args, tags), "Or": lambda cond_, tags: _binary_op_reduce("LogicalOr", cond_.args, tags), "__le__": lambda cond_, tags: _binary_op_reduce("CmpLE", cond_.args, tags, signed=True), "SLE": lambda cond_, tags: _binary_op_reduce("CmpLE", cond_.args, tags, signed=True), "__lt__": lambda cond_, tags: _binary_op_reduce("CmpLT", cond_.args, tags, signed=True), "SLT": lambda cond_, tags: _binary_op_reduce("CmpLT", cond_.args, tags, signed=True), "UGT": lambda cond_, tags: _binary_op_reduce("CmpGT", cond_.args, tags), "UGE": lambda cond_, tags: _binary_op_reduce("CmpGE", cond_.args, tags), "__gt__": lambda cond_, tags: _binary_op_reduce("CmpGT", cond_.args, tags, signed=True), "__ge__": lambda cond_, tags: _binary_op_reduce("CmpGE", cond_.args, tags, signed=True), "SGT": lambda cond_, tags: _binary_op_reduce("CmpGT", cond_.args, tags, signed=True), "SGE": lambda cond_, tags: _binary_op_reduce("CmpGE", cond_.args, tags, signed=True), "ULT": lambda cond_, tags: _binary_op_reduce("CmpLT", cond_.args, tags), "ULE": lambda cond_, tags: _binary_op_reduce("CmpLE", cond_.args, tags), "__eq__": lambda cond_, tags: _binary_op_reduce("CmpEQ", cond_.args, tags), "__ne__": lambda cond_, tags: _binary_op_reduce("CmpNE", cond_.args, tags), "__add__": lambda cond_, tags: _binary_op_reduce("Add", cond_.args, tags, signed=False), "__sub__": lambda cond_, tags: _binary_op_reduce("Sub", cond_.args, tags), "__mul__": lambda cond_, tags: _binary_op_reduce("Mul", cond_.args, tags), "__xor__": lambda cond_, tags: _binary_op_reduce("Xor", cond_.args, tags), "__or__": lambda cond_, tags: _binary_op_reduce("Or", cond_.args, tags, signed=False), "__and__": lambda cond_, tags: _binary_op_reduce("And", cond_.args, tags), "__lshift__": lambda cond_, tags: _binary_op_reduce("Shl", cond_.args, tags), "__rshift__": lambda cond_, tags: _binary_op_reduce("Sar", cond_.args, tags), "__floordiv__": lambda cond_, tags: _binary_op_reduce("Div", cond_.args, tags), "__mod__": lambda cond_, tags: _binary_op_reduce("Mod", cond_.args, tags), "LShR": lambda cond_, tags: _binary_op_reduce("Shr", cond_.args, tags), "BVV": lambda cond_, tags: ailment.Expr.Const(None, None, cond_.args[0], cond_.size(), **tags), "BoolV": lambda cond_, tags: ( ailment.Expr.Const(None, None, True, 1, **tags) if cond_.args[0] is True else ailment.Expr.Const(None, None, False, 1, **tags) ), "Extract": lambda cond_, tags: self._convert_extract(*cond_.args, tags, memo=memo), "ZeroExt": lambda cond_, tags: _binary_op_reduce( "Concat", [claripy.BVV(0, cond_.args[0]), cond_.args[1]], tags ), "Concat": lambda cond_, tags: _binary_op_reduce("Concat", cond_.args, tags), } if cond.op in _mapping: if cond in self._ast2annotations: cond_tags = self._ast2annotations.get(cond) elif claripy.Not(cond) in self._ast2annotations: cond_tags = self._ast2annotations.get(claripy.Not(cond)) else: cond_tags = {} return _mapping[cond.op](cond, cond_tags) raise NotImplementedError( f"Condition variable {cond} has an unsupported operator {cond.op}. Consider implementing." )
[文档] def claripy_ast_from_ail_condition( self, condition, nobool: bool = False, *, ins_addr: int = 0 ) -> claripy.ast.Bool | claripy.ast.Bits: # Unpack a condition all the way to the leaves if isinstance( condition, (claripy.ast.Bits, claripy.ast.Bool) ): # pylint:disable=isinstance-second-argument-not-valid-type return condition if isinstance( condition, (ailment.Expr.VEXCCallExpression, ailment.Expr.BasePointerOffset, ailment.Expr.ITE), ): return _dummy_bvs(condition, self._condition_mapping) if isinstance(condition, ailment.Stmt.Call): return _dummy_bvs(condition, self._condition_mapping, name_suffix=hex(condition.tags.get("ins_addr", 0))) if isinstance(condition, (ailment.Expr.Load, ailment.Expr.Register, ailment.Expr.VirtualVariable)): # does it have a variable associated? if condition.variable is not None: var = claripy.BVS( f"ailexpr_{condition!r}-{condition.variable.ident}-{ins_addr:x}", condition.bits, explicit_name=True, ) else: var = claripy.BVS( f"ailexpr_{condition!r}-{condition.idx}-{ins_addr:x}", condition.bits, explicit_name=True ) self._condition_mapping[var.args[0]] = condition return var if isinstance(condition, ailment.Expr.Convert): # convert is special. if it generates a 1-bit variable, it should be treated as a BoolS if condition.to_bits == 1: var_ = self.claripy_ast_from_ail_condition(condition.operands[0], ins_addr=ins_addr) name = f"ailcond_Conv({condition.from_bits}->{condition.to_bits}, {hash(var_)})" var = claripy.BoolS(name, explicit_name=True) else: var_ = self.claripy_ast_from_ail_condition(condition.operands[0], ins_addr=ins_addr) name = f"ailexpr_Conv({condition.from_bits}->{condition.to_bits}, {hash(var_)})" var = claripy.BVS(name, condition.to_bits, explicit_name=True) self._condition_mapping[var.args[0]] = condition return var if isinstance(condition, ailment.Expr.Const): if condition.value is True or condition.value is False: var = claripy.BoolV(condition.value) else: var = claripy.BVV(condition.value, condition.bits) if isinstance(var, claripy.ast.Bits) and var.size() == 1: var = claripy.true() if var.concrete_value == 1 else claripy.false() return var if isinstance(condition, ailment.Expr.Tmp): l.warning("Left-over ailment.Tmp variable %s.", condition) if condition.bits == 1: var = claripy.BoolS(f"ailtmp_{condition.tmp_idx}", explicit_name=True) else: var = claripy.BVS(f"ailtmp_{condition.tmp_idx}", condition.bits, explicit_name=True) self._condition_mapping[var.args[0]] = condition return var if isinstance(condition, ailment.Expr.MultiStatementExpression): # just cache it if condition.bits == 1: var = claripy.BoolS(f"mstmtexpr_{hash(condition)}", explicit_name=True) else: var = claripy.BVS(f"mstmtexpr_{hash(condition)}", condition.bits, explicit_name=True) self._condition_mapping[var.args[0]] = condition return var lambda_expr = _ail2claripy_op_mapping.get(condition.verbose_op, None) if lambda_expr is None: # fall back to op lambda_expr = _ail2claripy_op_mapping.get(condition.op, None) if lambda_expr is None: # fall back to the catch-all option l.debug( "Unsupported AIL expression operation %s (or verbose: %s). Fall back to the default catch-all dummy " "option. Consider implementing.", condition.op, condition.verbose_op, ) lambda_expr = _ail2claripy_op_mapping["_DUMMY_"] r = lambda_expr(condition, self.claripy_ast_from_ail_condition, self._condition_mapping, ins_addr) if isinstance(r, claripy.ast.Bool) and nobool: r = claripy.BVS(f"ailexpr_from_bool_{r!r}", 1, explicit_name=True) self._condition_mapping[r.args[0]] = condition if r is NotImplemented: if condition.bits == 1 and not nobool: r = claripy.BoolS(f"ailexpr_{condition!r}", explicit_name=True) else: r = claripy.BVS(f"ailexpr_{condition!r}", condition.bits, explicit_name=True) self._condition_mapping[r.args[0]] = condition # don't lose tags self._ast2annotations[r] = condition.tags return r
# # Expression simplification #
[文档] @staticmethod def claripy_ast_to_sympy_expr(ast, memo=None): if ast.op == "And": return sympy.And(*(ConditionProcessor.claripy_ast_to_sympy_expr(arg, memo=memo) for arg in ast.args)) if ast.op == "Or": return sympy.Or(*(ConditionProcessor.claripy_ast_to_sympy_expr(arg, memo=memo) for arg in ast.args)) if ast.op == "Not": return sympy.Not(ConditionProcessor.claripy_ast_to_sympy_expr(ast.args[0], memo=memo)) if ast.op in _UNIFIABLE_COMPARISONS: # unify comparisons to enable more simplification opportunities without going "deep" in sympy inverse_op = getattr(ast.args[0], _INVERSE_OPERATIONS[ast.op]) return sympy.Not(ConditionProcessor.claripy_ast_to_sympy_expr(inverse_op(ast.args[1]), memo=memo)) if memo is not None and ast in memo: return memo[ast] symbol = sympy.Symbol(str(hash(ast))) if memo is not None: memo[symbol] = ast return symbol
[文档] @staticmethod def sympy_expr_to_claripy_ast(expr, memo: dict): if expr.is_Symbol: return memo[expr] if isinstance(expr, sympy.Or): return claripy.Or(*(ConditionProcessor.sympy_expr_to_claripy_ast(arg, memo) for arg in expr.args)) if isinstance(expr, sympy.And): return claripy.And(*(ConditionProcessor.sympy_expr_to_claripy_ast(arg, memo) for arg in expr.args)) if isinstance(expr, sympy.Not): return claripy.Not(ConditionProcessor.sympy_expr_to_claripy_ast(expr.args[0], memo)) if isinstance(expr, sympy.logic.boolalg.BooleanTrue): return claripy.true() if isinstance(expr, sympy.logic.boolalg.BooleanFalse): return claripy.false() raise AngrRuntimeError("Unreachable reached")
[文档] @staticmethod def simplify_condition(cond, depth_limit=8, variables_limit=8): memo = {} if cond.depth > depth_limit or len(cond.variables) > variables_limit: return cond sympy_expr = ConditionProcessor.claripy_ast_to_sympy_expr(cond, memo=memo) return ConditionProcessor.sympy_expr_to_claripy_ast(sympy.simplify_logic(sympy_expr, deep=False), memo)
[文档] @staticmethod def simplify_condition_deprecated(cond): # Z3's simplification may yield weird and unreadable results # hence we mostly rely on our own simplification. we only use Z3's simplification results when it returns a # concrete value. claripy_simplified = claripy.simplify(cond) if not claripy_simplified.symbolic: return claripy_simplified simplified = ConditionProcessor._fold_double_negations(cond) cond = simplified if simplified is not None else cond simplified = ConditionProcessor._revert_short_circuit_conditions(cond) cond = simplified if simplified is not None else cond simplified = ConditionProcessor._extract_common_subexpressions(cond) cond = simplified if simplified is not None else cond # simplified = ConditionProcessor._remove_redundant_terms(cond) # cond = simplified if simplified is not None else cond # in the end, use claripy's simplification to handle really easy cases again simplified = ConditionProcessor._simplify_trivial_cases(cond) return simplified if simplified is not None else cond
@staticmethod def _simplify_trivial_cases(cond): if cond.op == "And": new_args = [] for arg in cond.args: claripy_simplified = claripy.simplify(arg) if claripy.is_true(claripy_simplified): continue new_args.append(arg) return claripy.And(*new_args) return None @staticmethod def _revert_short_circuit_conditions(cond): # revert short-circuit conditions # !A||(A&&!B) ==> !(A&&B) if cond.op != "Or": return cond if len(cond.args) == 1: # redundant operator. get rid of it return cond.args[0] or_arg0, or_arg1 = cond.args[:2] if or_arg1.op == "And": pass elif or_arg0.op == "And": or_arg0, or_arg1 = or_arg1, or_arg0 else: return cond not_a = or_arg0 solver = claripy.SolverCacheless() if not_a.variables == or_arg1.args[0].variables: solver.add(not_a == or_arg1.args[0]) not_b = or_arg1.args[1] elif not_a.variables == or_arg1.args[1].variables: solver.add(not_a == or_arg1.args[1]) not_b = or_arg1.args[0] else: return cond if not solver.satisfiable(): # found it! b = claripy.Not(not_b) a = claripy.Not(not_a) if len(cond.args) <= 2: return claripy.Not(claripy.And(a, b)) return claripy.Or(claripy.Not(claripy.And(a, b)), *cond.args[2:]) return cond @staticmethod def _fold_double_negations(cond): # !(!A) ==> A # !((!A) && (!B)) ==> A || B # !((!A) && B) ==> A || !B # !(A || B) ==> (!A && !B) if cond.op != "Not": return None if cond.args[0].op == "Not": return cond.args[0] if cond.args[0].op == "And" and len(cond.args[0].args) == 2: and_0, and_1 = cond.args[0].args if and_0.op == "Not" and and_1.op == "Not": return claripy.Or(and_0.args[0], and_1.args[0]) if and_0.op == "Not": # and_1.op != "Not" return claripy.Or(and_0.args[0], ConditionProcessor.simplify_condition(claripy.Not(and_1))) if cond.args[0].op == "Or" and len(cond.args[0].args) == 2: or_0, or_1 = cond.args[0].args return claripy.And( ConditionProcessor.simplify_condition(claripy.Not(or_0)), ConditionProcessor.simplify_condition(claripy.Not(or_1)), ) return None @staticmethod def _extract_common_subexpressions(cond): def _expr_inside_collection(expr_, coll_) -> bool: return any(expr_ is ex_ for ex_ in coll_) # (A && B) || (A && C) => A && (B || C) if cond.op == "And": args = [ConditionProcessor._extract_common_subexpressions(arg) for arg in cond.args] if all(arg is None for arg in args): return None return claripy.And(*((arg if arg is not None else ori_arg) for arg, ori_arg in zip(args, cond.args))) if cond.op == "Or": args = [ConditionProcessor._extract_common_subexpressions(arg) for arg in cond.args] args = [(arg if arg is not None else ori_arg) for arg, ori_arg in zip(args, cond.args)] expr_ctrs = defaultdict(int) for arg in args: if arg.op == "And": for subexpr in arg.args: expr_ctrs[subexpr] += 1 else: expr_ctrs[arg] += 1 common_exprs = [] for expr, ctr in expr_ctrs.items(): if ctr == len(args): # found a common one common_exprs.append(expr) if not common_exprs: return claripy.Or(*args) new_args = [] for arg in args: if arg.op == "And": new_subexprs = [ subexpr for subexpr in arg.args if not _expr_inside_collection(subexpr, common_exprs) ] new_args.append(claripy.And(*new_subexprs)) elif arg in common_exprs: continue else: raise AngrRuntimeError("Unexpected behavior - you should never reach here") return claripy.And(*common_exprs, claripy.Or(*new_args)) return None @staticmethod def _extract_terms(ast: claripy.ast.Bool) -> Generator[claripy.ast.Bool]: if ast.op == "And" or ast.op == "Or": for arg in ast.args: yield from ConditionProcessor._extract_terms(arg) elif ast.op == "Not": yield from ConditionProcessor._extract_terms(ast.args[0]) else: yield ast @staticmethod def _replace_term_in_ast( ast: claripy.ast.Bool, r0: claripy.ast.Bool, r0_with: claripy.ast.Bool, r1: claripy.ast.Bool, r1_with: claripy.ast.Bool, ) -> claripy.ast.Bool: if ast.op == "And": return claripy.And( *(ConditionProcessor._replace_term_in_ast(arg, r0, r0_with, r1, r1_with) for arg in ast.args) ) if ast.op == "Or": return claripy.Or( *(ConditionProcessor._replace_term_in_ast(arg, r0, r0_with, r1, r1_with) for arg in ast.args) ) if ast.op == "Not": return claripy.Not(ConditionProcessor._replace_term_in_ast(ast.args[0], r0, r0_with, r1, r1_with)) if ast is r0: return r0_with if ast is r1: return r1_with return ast @staticmethod def _remove_redundant_terms(cond): """ Extract all terms and test for each term if its truism impacts the truism of the entire condition. If not, the term is redundant and can be replaced with a True. """ all_terms = set() for term in ConditionProcessor._extract_terms(cond): if term not in all_terms: all_terms.add(term) negations = {} to_skip = set() all_terms_without_negs = set() for term in all_terms: if term in to_skip: continue neg = claripy.Not(term) if neg in all_terms: negations[term] = neg to_skip.add(neg) all_terms_without_negs.add(term) else: all_terms_without_negs.add(term) solver = claripy.SolverCacheless() for term in all_terms_without_negs: neg = negations.get(term) replaced_with_true = ConditionProcessor._replace_term_in_ast( cond, term, claripy.true(), neg, claripy.false() ) sat0 = solver.satisfiable( extra_constraints=( cond, claripy.Not(replaced_with_true), ) ) sat1 = solver.satisfiable( extra_constraints=( claripy.Not(cond), replaced_with_true, ) ) if sat0 or sat1: continue replaced_with_false = ConditionProcessor._replace_term_in_ast( cond, term, claripy.false(), neg, claripy.true() ) sat0 = solver.satisfiable( extra_constraints=( cond, claripy.Not(replaced_with_false), ) ) sat1 = solver.satisfiable( extra_constraints=( claripy.Not(cond), replaced_with_false, ) ) if sat0 or sat1: continue # TODO: Finish the implementation print(term, "is redundant") # # Graph processing # @staticmethod def _remove_crossing_edges_between_cases( graph: networkx.DiGraph, case_entry_to_switch_head: dict[int, int] ) -> networkx.DiGraph: starting_nodes = {node for node in graph if node.addr in case_entry_to_switch_head} if not starting_nodes: return graph traversed_nodes = set() edges_to_remove = set() for starting_node in starting_nodes: queue = [starting_node] while queue: src = queue.pop(0) traversed_nodes.add(src) successors = graph.successors(src) for succ in successors: if succ in traversed_nodes: # we should not traverse this node twice if graph.out_degree(succ) > 0: edges_to_remove.add((src, succ)) continue if succ in starting_nodes: # we do not want any jump from one node to a starting node edges_to_remove.add((src, succ)) continue traversed_nodes.add(src) queue.append(succ) if not edges_to_remove: return graph # make a copy before modifying the graph graph = networkx.DiGraph(graph) graph.remove_edges_from(edges_to_remove) return graph # # Utils #
[文档] def create_jump_target_var(self, jumptable_head_addr: int): return claripy.BVS(f"jump_table_{jumptable_head_addr:x}", self.arch.bits, explicit_name=True)