# Source code for angr.analyses.decompiler.optimization_passes.ret_deduplicator

# pylint:disable=unnecessary-pass
from __future__ import annotations
import logging

from ailment import Block
from ailment.statement import ConditionalJump, Return

from angr.analyses.decompiler.structuring import SAILRStructurer, DreamStructurer
from angr.utils.graph import subgraph_between_nodes
from angr.analyses.decompiler.utils import remove_labels, to_ail_supergraph, update_labels
from .optimization_pass import OptimizationPass, OptimizationPassStage


# Module-level logger, named after this module.
_l = logging.getLogger(__name__)


class ReturnDeduplicator(OptimizationPass):
    """
    Transforms:
    - if (cond) { ... return x; } return x;
    into:
    - if (cond) { ... } return x;

    TODO: its possible that this can be expanded to all rets that are equivalent. Testing needed.
    """

    # NOTE: DESCRIPTION is computed from the class docstring above, so the docstring text is
    # runtime data; edit it with care.
    ARCHES = ["X86", "AMD64", "ARMEL", "ARMHF", "ARMCortexM", "MIPS32", "MIPS64"]
    PLATFORMS = ["windows", "linux", "cgc"]
    STAGE = OptimizationPassStage.DURING_REGION_IDENTIFICATION
    NAME = "Deduplicates return statements that may have been duplicated"
    DESCRIPTION = __doc__.strip()
    STRUCTURING = [SAILRStructurer.NAME, DreamStructurer.NAME]

    def __init__(self, func, **kwargs):
        """
        Set up the pass through the base-class constructor and run the analysis right away.

        :param func:    Forwarded unchanged to ``OptimizationPass.__init__``.
        :param kwargs:  Forwarded unchanged to ``OptimizationPass.__init__``.
        """
        super().__init__(func, **kwargs)
        self.analyze()

    def _check(self):
        """
        Unconditionally return ``(True, None)``.

        NOTE(review): presumably ``(should_run, cache)`` as expected by the base class; confirm against
        ``OptimizationPass``.
        """
        return True, None

    def _analyze(self, cache=None):
        """
        Find every if-return region and collapse each one; publish the graph if anything changed.

        :param cache:   Unused by this pass.
        """
        graph_updated = False
        if_ret_regions = self._find_if_ret_regions()
        for region_head, true_child, false_child, super_true, super_false in if_ret_regions:
            graph_updated |= self._fix_if_ret_region(region_head, true_child, false_child, super_true, super_false)

        # only hand back an output graph when at least one region was actually rewritten
        if graph_updated:
            self.out_graph = update_labels(self._graph)

    def _fix_if_ret_region(self, region_head, true_child, false_child, super_true, super_false):
        r"""
        Rewrite one if-return region so that both branches fall into a single shared return block::

                A
               / \
              C   B

              =>

                A
               / \
              C   B
               \ /
                D

        The super blocks of the true and false child will be used as the replacement for the true and false child
        to assure correctness.

        :param region_head:     Block in ``self._graph`` that ends with the conditional jump.
        :param true_child:      Block in ``self._graph`` whose last statement is the true-branch return.
        :param false_child:     Block in ``self._graph`` whose last statement is the false-branch return.
        :param super_true:      Super block standing in for the whole true branch.
        :param super_false:     Super block standing in for the whole false branch.
        :return:                True if the graph was modified, False if any of the three original blocks is no
                                longer in the graph.
        """
        # an earlier rewrite may already have removed some of these blocks
        if any(node not in self._graph for node in (region_head, true_child, false_child)):
            return False

        # destroy all but the head in the region
        for child in (true_child, false_child):
            region_nodes = subgraph_between_nodes(self._graph, region_head, [child], include_frontier=True)
            for node in region_nodes:
                if node is region_head:
                    continue
                self._remove_block(node)

        # replace the head with a new if-stmt corrected block
        # (the jump statement is patched in place; the copied head carries the retargeted statement)
        if_stmt = region_head.statements[-1]
        if_stmt.true_target.value = super_true.addr
        if_stmt.false_target.value = super_false.addr
        new_head = region_head.copy()
        new_head.statements[-1] = if_stmt
        # assures that preds still point to this block
        self._update_block(region_head, new_head)

        # create new children: copies of the super blocks with their trailing return stripped
        new_children = []
        for child in (super_true, super_false):
            new_child = child.copy()
            new_child.statements = new_child.statements[:-1]
            new_children.append(new_child)

        # create a new return block, keeping whichever return has the larger instruction address
        true_ret: Return = true_child.statements[-1]
        false_ret: Return = false_child.statements[-1]
        ret_stmt = true_ret if true_ret.ins_addr > false_ret.ins_addr else false_ret
        # XXX: this size is wrong, but unknown how to fix
        ret_block = Block(self.new_block_addr(), 1, [ret_stmt])

        # head -> [children]
        self._graph.add_edges_from([(new_head, new_children[0]), (new_head, new_children[1])])
        # [children] -> ret
        self._graph.add_edges_from([(new_children[0], ret_block), (new_children[1], ret_block)])
        return True

    def _find_if_ret_regions(self):
        """
        We are looking for patterns in the graph that match the following schema:
        A: if (cond) goto B else goto C;
        B: ...; return x;
        C: ...; return x;

        The search runs on a label-free supergraph of ``self._graph``; matches are then translated back to
        blocks of the original graph by ``_get_original_regions``.

        :return:    A list of ``[if_stmt_block, true_child, false_child, super_true_child, super_false_child]``.
        """
        # find all the if-stmt blocks in a graph with no single successor edges
        super_graph = to_ail_supergraph(remove_labels(self._graph))
        if_stmt_blocks = [
            node
            for node in super_graph.nodes()
            if node.statements and isinstance(node.statements[-1], ConditionalJump)
        ]

        if_ret_candidates = []
        for if_stmt_block in if_stmt_blocks:
            if_stmt = if_stmt_block.statements[-1]
            children = list(super_graph.successors(if_stmt_block))
            if len(children) != 2 or children[0] is children[1]:
                continue

            # find the true and false child of the if-stmt
            true_child, false_child = None, None
            for child in children:
                if child.addr == if_stmt.true_target.value:
                    true_child = child
                elif child.addr == if_stmt.false_target.value:
                    false_child = child

            # children must exist
            if (
                true_child is None
                or false_child is None
                or true_child not in super_graph
                or false_child not in super_graph
            ):
                continue

            # they must also have some statements to work with
            if not true_child.statements or not false_child.statements:
                continue

            # equivalent returns
            true_stmt = true_child.statements[-1]
            false_stmt = false_child.statements[-1]
            if (
                not isinstance(true_stmt, Return)
                or not isinstance(false_stmt, Return)
                or not true_stmt.likes(false_stmt)
            ):
                continue

            # both children must have only one predecessor
            if (
                len(list(super_graph.predecessors(true_child))) != 1
                or len(list(super_graph.predecessors(false_child))) != 1
            ):
                continue

            if_ret_candidates.append((if_stmt_block, true_child, false_child))

        return self._get_original_regions(if_ret_candidates)

    def _get_original_regions(self, if_ret_candidates: list[tuple[Block, Block, Block]]):
        """
        Input: [(if_stmt_block, super_true_child, super_false_child), ...]
        Output: [(if_stmt_block, true_child, false_child, super_true_child, super_false_child), ...]

        super_* is the associated super block in the original graph

        Blocks are matched between the two graphs by the ``(ins_addr, hash)`` of their last statement.
        Candidates for which any of the three blocks has no match in ``self._graph`` are dropped.
        """
        # re-find all the blocks we intend to delete in the original graph
        ids = {}
        for blocks in if_ret_candidates:
            for block in blocks:
                if not block.statements:
                    continue

                last_stmt = block.statements[-1]
                ids[(last_stmt.ins_addr, hash(last_stmt))] = block

        # map each super block to the original-graph block that ends with the same statement
        super_block_map = {}
        for block in self._graph.nodes():
            if not block.statements:
                continue

            last_stmt = block.statements[-1]
            stmt_id = (last_stmt.ins_addr, hash(last_stmt))
            if stmt_id in ids:
                super_block = ids[stmt_id]
                super_block_map[super_block] = block

        if_ret_regions = []
        for super_blocks in if_ret_candidates:
            corrected_region = []
            for super_block in super_blocks:
                block = super_block_map.get(super_block)
                if block is None:
                    break

                corrected_region.append(block)
            else:
                # all blocks were found: append the super blocks after the three original blocks
                # super_true
                corrected_region.append(super_blocks[1])
                # super_false
                corrected_region.append(super_blocks[2])
                if_ret_regions.append(corrected_region)

        return if_ret_regions