diff --git a/diff.py b/diff.py index 0e1da09deb..8f5daeada2 100755 --- a/diff.py +++ b/diff.py @@ -1,10 +1,28 @@ #!/usr/bin/env python3 +# PYTHON_ARGCOMPLETE_OK +import argparse import sys +from typing import ( + Any, + Dict, + List, + Match, + NamedTuple, + NoReturn, + Optional, + Set, + Tuple, + Union, + Callable, + Pattern, +) -def fail(msg): + +def fail(msg: str) -> NoReturn: print(msg, file=sys.stderr) sys.exit(1) + # Prefer to use diff_settings.py from the current working directory sys.path.insert(0, ".") try: @@ -19,21 +37,25 @@ try: import argcomplete # type: ignore except ModuleNotFoundError: argcomplete = None -import argparse -parser = argparse.ArgumentParser(description="Diff MIPS assembly.") +parser = argparse.ArgumentParser(description="Diff MIPS or AArch64 assembly.") + +start_argument = parser.add_argument( + "start", + help="Function name or address to start diffing from.", +) -start_argument = parser.add_argument("start", help="Function name or address to start diffing from.") if argcomplete: - def complete_symbol(**kwargs): - prefix = kwargs["prefix"] - if prefix == "": + + def complete_symbol( + prefix: str, parsed_args: argparse.Namespace, **kwargs: object + ) -> List[str]: + if not prefix or prefix.startswith("-"): # skip reading the map file, which would # result in a lot of useless completions return [] - parsed_args = kwargs["parsed_args"] - config = {} - diff_settings.apply(config, parsed_args) + config: Dict[str, Any] = {} + diff_settings.apply(config, parsed_args) # type: ignore mapfile = config.get("mapfile") if not mapfile: return [] @@ -64,20 +86,28 @@ if argcomplete: pos = data.find(search, endPos) completes.append(match) return completes - start_argument.completer = complete_symbol -parser.add_argument("end", nargs="?", help="Address to end diff at.") + setattr(start_argument, "completer", complete_symbol) + +parser.add_argument( + "end", + nargs="?", + help="Address to end diff at.", +) parser.add_argument( "-o", dest="diff_obj", action="store_true", - help="Diff .o files rather than a whole binary. This makes it possible to see symbol names. (Recommended)", + help="Diff .o files rather than a whole binary. This makes it possible to " + "see symbol names. (Recommended)", ) parser.add_argument( "-e", "--elf", dest="diff_elf_symbol", - help="Diff a given function in two ELFs, one being stripped and the other one non-stripped. Requires objdump from binutils 2.33+.", + metavar="SYMBOL", + help="Diff a given function in two ELFs, one being stripped and the other " + "one non-stripped. Requires objdump from binutils 2.33+.", ) parser.add_argument( "--source", @@ -114,6 +144,7 @@ parser.add_argument( dest="skip_lines", type=int, default=0, + metavar="LINES", help="Skip the first N lines of output.", ) parser.add_argument( @@ -130,6 +161,12 @@ parser.add_argument( action="store_true", help="Pretend all large enough immediates are the same.", ) +parser.add_argument( + "-I", + "--ignore-addr-diffs", + action="store_true", + help="Ignore address differences. Currently only affects AArch64.", +) parser.add_argument( "-B", "--no-show-branches", @@ -158,12 +195,22 @@ parser.add_argument( ) parser.add_argument( "-3", - "--threeway", + "--threeway=prev", dest="threeway", - action="store_true", + action="store_const", + const="prev", help="Show a three-way diff between target asm, current asm, and asm " "prior to -w rebuild. Requires -w.", ) +parser.add_argument( + "-b", + "--threeway=base", + dest="threeway", + action="store_const", + const="base", + help="Show a three-way diff between target asm, current asm, and asm " + "when diff.py was started. Requires -w.", +) parser.add_argument( "--width", dest="column_width", @@ -176,7 +223,8 @@ parser.add_argument( dest="algorithm", default="levenshtein", choices=["levenshtein", "difflib"], - help="Diff algorithm to use.", + help="Diff algorithm to use. Levenshtein gives the minimum diff, while difflib " + "aims for long sections of equal opcodes. Defaults to %(default)s.", ) parser.add_argument( "--max-size", @@ -188,14 +236,17 @@ parser.add_argument( ) # Project-specific flags, e.g. different versions/make arguments. -if hasattr(diff_settings, "add_custom_arguments"): - diff_settings.add_custom_arguments(parser) # type: ignore +add_custom_arguments_fn = getattr(diff_settings, "add_custom_arguments", None) +if add_custom_arguments_fn: + add_custom_arguments_fn(parser) if argcomplete: argcomplete.autocomplete(parser) # ==== IMPORTS ==== +# (We do imports late to optimize auto-complete performance.) + import re import os import ast @@ -206,7 +257,6 @@ import itertools import threading import queue import time -from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple, Union MISSING_PREREQUISITES = ( @@ -227,20 +277,22 @@ args = parser.parse_args() # Set imgs, map file and make flags in a project-specific manner. config: Dict[str, Any] = {} -diff_settings.apply(config, args) +diff_settings.apply(config, args) # type: ignore -arch = config.get("arch", "mips") -baseimg = config.get("baseimg", None) -myimg = config.get("myimg", None) -mapfile = config.get("mapfile", None) -makeflags = config.get("makeflags", []) -source_directories = config.get("source_directories", None) -objdump_executable = config.get("objdump_executable", None) +arch: str = config.get("arch", "mips") +baseimg: Optional[str] = config.get("baseimg") +myimg: Optional[str] = config.get("myimg") +mapfile: Optional[str] = config.get("mapfile") +makeflags: List[str] = config.get("makeflags", []) +source_directories: Optional[List[str]] = config.get("source_directories") +objdump_executable: Optional[str] = config.get("objdump_executable") +map_format: str = config.get("map_format", "gnu") +mw_build_dir: str = config.get("mw_build_dir", "build/") -MAX_FUNCTION_SIZE_LINES = args.max_lines -MAX_FUNCTION_SIZE_BYTES = MAX_FUNCTION_SIZE_LINES * 4 +MAX_FUNCTION_SIZE_LINES: int = args.max_lines +MAX_FUNCTION_SIZE_BYTES: int = MAX_FUNCTION_SIZE_LINES * 4 -COLOR_ROTATION = [ +COLOR_ROTATION: List[str] = [ Fore.MAGENTA, Fore.CYAN, Fore.GREEN, @@ -252,14 +304,16 @@ COLOR_ROTATION = [ Fore.LIGHTBLACK_EX, ] -BUFFER_CMD = ["tail", "-c", str(10 ** 9)] -LESS_CMD = ["less", "-SRic", "-#6"] +BUFFER_CMD: List[str] = ["tail", "-c", str(10 ** 9)] +LESS_CMD: List[str] = ["less", "-SRic", "-#6"] -DEBOUNCE_DELAY = 0.1 -FS_WATCH_EXTENSIONS = [".c", ".h"] +DEBOUNCE_DELAY: float = 0.1 +FS_WATCH_EXTENSIONS: List[str] = [".c", ".h"] # ==== LOGIC ==== +ObjdumpCommand = Tuple[List[str], str, Optional[str]] + if args.algorithm == "levenshtein": try: import Levenshtein # type: ignore @@ -272,6 +326,9 @@ if args.source: except ModuleNotFoundError as e: fail(MISSING_PREREQUISITES.format(e.name)) +if args.threeway and not args.watch: + fail("Threeway diffing requires -w.") + if objdump_executable is None: for objdump_cand in ["mips-linux-gnu-objdump", "mips64-elf-objdump"]: try: @@ -293,35 +350,41 @@ if not objdump_executable: ) -def eval_int(expr, emsg=None): +def maybe_eval_int(expr: str) -> Optional[int]: try: ret = ast.literal_eval(expr) if not isinstance(ret, int): raise Exception("not an integer") return ret except Exception: - if emsg is not None: - fail(emsg) return None -def eval_line_num(expr): +def eval_int(expr: str, emsg: str) -> int: + ret = maybe_eval_int(expr) + if ret is None: + fail(emsg) + return ret + + +def eval_line_num(expr: str) -> int: return int(expr.strip().replace(":", ""), 16) -def run_make(target, capture_output=False): - if capture_output: - return subprocess.run( - ["make"] + makeflags + [target], - stderr=subprocess.PIPE, - stdout=subprocess.PIPE, - ) - else: - subprocess.check_call(["make"] + makeflags + [target]) +def run_make(target: str) -> None: + subprocess.check_call(["make"] + makeflags + [target]) -def restrict_to_function(dump, fn_name): - out = [] +def run_make_capture_output(target: str) -> "subprocess.CompletedProcess[bytes]": + return subprocess.run( + ["make"] + makeflags + [target], + stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + ) + + +def restrict_to_function(dump: str, fn_name: str) -> str: + out: List[str] = [] search = f"<{fn_name}>:" found = False for line in dump.split("\n"): @@ -334,13 +397,13 @@ def restrict_to_function(dump, fn_name): return "\n".join(out) -def maybe_get_objdump_source_flags(): +def maybe_get_objdump_source_flags() -> List[str]: if not args.source: return [] flags = [ "--source", - "--source-comment=| ", + "--source-comment=│ ", "-l", ] @@ -350,8 +413,9 @@ def maybe_get_objdump_source_flags(): return flags -def run_objdump(cmd): +def run_objdump(cmd: ObjdumpCommand) -> str: flags, target, restrict = cmd + assert objdump_executable, "checked previously" out = subprocess.check_output( [objdump_executable] + arch_flags + flags + [target], universal_newlines=True ) @@ -360,53 +424,76 @@ def run_objdump(cmd): return out -base_shift = eval_int( +base_shift: int = eval_int( args.base_shift, "Failed to parse --base-shift (-S) argument as an integer." ) -def search_map_file(fn_name): +def search_map_file(fn_name: str) -> Tuple[Optional[str], Optional[int]]: if not mapfile: fail(f"No map file configured; cannot find function {fn_name}.") try: with open(mapfile) as f: - lines = f.read().split("\n") + contents = f.read() except Exception: fail(f"Failed to open map file {mapfile} for reading.") - try: - cur_objfile = None - ram_to_rom = None - cands = [] - last_line = "" - for line in lines: - if line.startswith(" .text"): - cur_objfile = line.split()[3] - if "load address" in line: - tokens = last_line.split() + line.split() - ram = int(tokens[1], 0) - rom = int(tokens[5], 0) - ram_to_rom = rom - ram - if line.endswith(" " + fn_name): - ram = int(line.split()[0], 0) - if cur_objfile is not None and ram_to_rom is not None: - cands.append((cur_objfile, ram + ram_to_rom)) - last_line = line - except Exception as e: - import traceback + if map_format == 'gnu': + lines = contents.split("\n") - traceback.print_exc() - fail(f"Internal error while parsing map file") + try: + cur_objfile = None + ram_to_rom = None + cands = [] + last_line = "" + for line in lines: + if line.startswith(" .text"): + cur_objfile = line.split()[3] + if "load address" in line: + tokens = last_line.split() + line.split() + ram = int(tokens[1], 0) + rom = int(tokens[5], 0) + ram_to_rom = rom - ram + if line.endswith(" " + fn_name): + ram = int(line.split()[0], 0) + if cur_objfile is not None and ram_to_rom is not None: + cands.append((cur_objfile, ram + ram_to_rom)) + last_line = line + except Exception as e: + import traceback - if len(cands) > 1: - fail(f"Found multiple occurrences of function {fn_name} in map file.") - if len(cands) == 1: - return cands[0] + traceback.print_exc() + fail(f"Internal error while parsing map file") + + if len(cands) > 1: + fail(f"Found multiple occurrences of function {fn_name} in map file.") + if len(cands) == 1: + return cands[0] + elif map_format == 'mw': + # ram elf rom object name + find = re.findall(re.compile(r' \S+ \S+ (\S+) (\S+) . ' + fn_name + r'(?: \(entry of \.(?:init|text)\))? \t(\S+)'), contents) + if len(find) > 1: + fail(f"Found multiple occurrences of function {fn_name} in map file.") + if len(find) == 1: + rom = int(find[0][1],16) + objname = find[0][2] + # The metrowerks linker map format does not contain the full object path, so we must complete it manually. + objfiles = [os.path.join(dirpath, f) for dirpath, _, filenames in os.walk(mw_build_dir) for f in filenames if f == objname] + if len(objfiles) > 1: + all_objects = "\n".join(objfiles) + fail(f"Found multiple objects of the same name {objname} in {mw_build_dir}, cannot determine which to diff against: \n{all_objects}") + if len(objfiles) == 1: + objfile = objfiles[0] + # TODO Currently the ram-rom conversion only works for diffing ELF executables, but it would likely be more convenient to diff DOLs. + # At this time it is recommended to always use -o when running the diff script as this mode does not make use of the ram-rom conversion + return objfile, rom + else: + fail(f"Linker map format {map_format} unrecognised.") return None, None -def dump_elf(): +def dump_elf() -> Tuple[str, ObjdumpCommand, ObjdumpCommand]: if not baseimg or not myimg: fail("Missing myimg/baseimg in config.") if base_shift: @@ -436,7 +523,7 @@ def dump_elf(): ) -def dump_objfile(): +def dump_objfile() -> Tuple[str, ObjdumpCommand, ObjdumpCommand]: if base_shift: fail("--base-shift not compatible with -o") if args.end is not None: @@ -466,12 +553,12 @@ def dump_objfile(): ) -def dump_binary(): +def dump_binary() -> Tuple[str, ObjdumpCommand, ObjdumpCommand]: if not baseimg or not myimg: fail("Missing myimg/baseimg in config.") if args.make: run_make(myimg) - start_addr = eval_int(args.start) + start_addr = maybe_eval_int(args.start) if start_addr is None: _, start_addr = search_map_file(args.start) if start_addr is None: @@ -480,7 +567,7 @@ def dump_binary(): end_addr = eval_int(args.end, "End address must be an integer expression.") else: end_addr = start_addr + MAX_FUNCTION_SIZE_BYTES - objdump_flags = ["-Dz", "-bbinary", "-mmips", "-EB"] + objdump_flags = ["-Dz", "-bbinary", "-EB"] flags1 = [ f"--start-address={start_addr + base_shift}", f"--stop-address={end_addr + base_shift}", @@ -493,9 +580,9 @@ def dump_binary(): ) -# Alignment with ANSI colors is broken, let's fix it. -def ansi_ljust(s, width): - needed = width - ansiwrap.ansilen(s) +def ansi_ljust(s: str, width: int) -> str: + """Like s.ljust(width), but accounting for ANSI colors.""" + needed: int = width - ansiwrap.ansilen(s) if needed > 0: return s + " " * needed else: @@ -505,7 +592,9 @@ def ansi_ljust(s, width): if arch == "mips": re_int = re.compile(r"[0-9]+") re_comment = re.compile(r"<.*?>") - re_reg = re.compile(r"\$?\b(a[0-3]|t[0-9]|s[0-8]|at|v[01]|f[12]?[0-9]|f3[01]|k[01]|fp|ra)\b") + re_reg = re.compile( + r"\$?\b(a[0-3]|t[0-9]|s[0-8]|at|v[01]|f[12]?[0-9]|f3[01]|k[01]|fp|ra|zero)\b" + ) re_sprel = re.compile(r"(?<=,)([0-9]+|0x[0-9a-f]+)\(sp\)") re_large_imm = re.compile(r"-?[1-9][0-9]{2,}|-?0x[0-9a-f]{3,}") re_imm = re.compile(r"(\b|-)([0-9]+|0x[0-9a-fA-F]+)\b(?!\(sp)|%(lo|hi)\([^)]*\)") @@ -524,7 +613,19 @@ if arch == "mips": "bc1fl", } branch_instructions = branch_likely_instructions.union( - {"b", "beq", "bne", "beqz", "bnez", "bgez", "bgtz", "blez", "bltz", "bc1t", "bc1f"} + { + "b", + "beq", + "bne", + "beqz", + "bnez", + "bgez", + "bgtz", + "blez", + "bltz", + "bc1t", + "bc1f", + } ) instructions_with_address_immediates = branch_instructions.union({"jal", "j"}) elif arch == "aarch64": @@ -539,13 +640,71 @@ elif arch == "aarch64": arch_flags = [] forbidden = set(string.ascii_letters + "_") branch_likely_instructions = set() - branch_instructions = {"bl", "b", "b.eq", "b.ne", "b.cs", "b.hs", "b.cc", "b.lo", "b.mi", "b.pl", "b.vs", "b.vc", "b.hi", "b.ls", "b.ge", "b.lt", "b.gt", "b.le", "cbz", "cbnz", "tbz", "tbnz"} + branch_instructions = { + "bl", + "b", + "b.eq", + "b.ne", + "b.cs", + "b.hs", + "b.cc", + "b.lo", + "b.mi", + "b.pl", + "b.vs", + "b.vc", + "b.hi", + "b.ls", + "b.ge", + "b.lt", + "b.gt", + "b.le", + "cbz", + "cbnz", + "tbz", + "tbnz", + } instructions_with_address_immediates = branch_instructions.union({"adrp"}) +elif arch == "ppc": + re_int = re.compile(r"[0-9]+") + re_comment = re.compile(r"(<.*?>|//.*$)") + re_reg = re.compile(r"\$?\b([rf][0-9]+)\b") + re_sprel = re.compile(r"(?<=,)(-?[0-9]+|-?0x[0-9a-f]+)\(r1\)") + re_large_imm = re.compile(r"-?[1-9][0-9]{2,}|-?0x[0-9a-f]{3,}") + re_imm = re.compile(r"(\b|-)([0-9]+|0x[0-9a-fA-F]+)\b(?!\(r1)|[^@]*@(ha|h|lo)") + arch_flags = [] + forbidden = set(string.ascii_letters + "_") + branch_likely_instructions = set() + branch_instructions = { + "b", + "beq", + "beq+", + "beq-", + "bne", + "bne+", + "bne-", + "blt", + "blt+", + "blt-", + "ble", + "ble+", + "ble-", + "bdnz", + "bdnz+", + "bdnz-", + "bge", + "bge+", + "bge-", + "bgt", + "bgt+", + "bgt-", + } + instructions_with_address_immediates = branch_instructions.union({"bl"}) else: - fail("Unknown architecture.") + fail(f"Unknown architecture: {arch}") -def hexify_int(row, pat): +def hexify_int(row: str, pat: Match[str]) -> str: full = pat.group(0) if len(full) <= 1: # leave one-digit ints alone @@ -558,11 +717,14 @@ def hexify_int(row, pat): return hex(int(full)) -def parse_relocated_line(line): +def parse_relocated_line(line: str) -> Tuple[str, str, str]: try: ind2 = line.rindex(",") except ValueError: - ind2 = line.rindex("\t") + try: + ind2 = line.rindex("\t") + except ValueError: + ind2 = line.rindex(" ") before = line[: ind2 + 1] after = line[ind2 + 1 :] ind2 = after.find("(") @@ -575,7 +737,7 @@ def parse_relocated_line(line): return before, imm, after -def process_mips_reloc(row, prev): +def process_mips_reloc(row: str, prev: str) -> str: before, imm, after = parse_relocated_line(prev) repl = row.split()[-1] if imm != "0": @@ -596,12 +758,48 @@ def process_mips_reloc(row, prev): # correct addend for each, but objdump doesn't give us the order of # the relocations, so we can't find the right LO16. :( repl = f"%hi({repl})" + elif "R_MIPS_26" in row: + # Function calls + pass + elif "R_MIPS_PC16" in row: + # Branch to glabel. This gives confusing output, but there's not much + # we can do here. + pass else: - assert "R_MIPS_26" in row, f"unknown relocation type '{row}'" + assert False, f"unknown relocation type '{row}' for line '{prev}'" return before + repl + after -def pad_mnemonic(line): +def process_ppc_reloc(row: str, prev: str) -> str: + assert any(r in row for r in ["R_PPC_REL24", "R_PPC_ADDR16", "R_PPC_EMB_SDA21"]), f"unknown relocation type '{row}' for line '{prev}'" + before, imm, after = parse_relocated_line(prev) + repl = row.split()[-1] + if "R_PPC_REL24" in row: + # function calls + pass + elif "R_PPC_ADDR16_HI" in row: + # absolute hi of addr + repl = f"{repl}@h" + elif "R_PPC_ADDR16_HA" in row: + # adjusted hi of addr + repl = f"{repl}@ha" + elif "R_PPC_ADDR16_LO" in row: + # lo of addr + repl = f"{repl}@l" + elif "R_PPC_ADDR16" in row: + # 16-bit absolute addr + if "+0x7" in repl: + # remove the very large addends as they are an artifact of (label-_SDA(2)_BASE_) + # computations and are unimportant in a diff setting. + if int(repl.split("+")[1],16) > 0x70000000: + repl = repl.split("+")[0] + elif "R_PPC_EMB_SDA21" in row: + # small data area + pass + return before + repl + after + + +def pad_mnemonic(line: str) -> str: if "\t" not in line: return line mn, args = line.split("\t", 1) @@ -612,13 +810,82 @@ class Line(NamedTuple): mnemonic: str diff_row: str original: str + normalized_original: str line_num: str branch_target: Optional[str] source_lines: List[str] comment: Optional[str] -def process(lines): +class DifferenceNormalizer: + def normalize(self, mnemonic: str, row: str) -> str: + """This should be called exactly once for each line.""" + row = self._normalize_arch_specific(mnemonic, row) + if args.ignore_large_imms: + row = re.sub(re_large_imm, "", row) + return row + + def _normalize_arch_specific(self, mnemonic: str, row: str) -> str: + return row + + +class DifferenceNormalizerAArch64(DifferenceNormalizer): + def __init__(self) -> None: + super().__init__() + self._adrp_pair_registers: Set[str] = set() + + def _normalize_arch_specific(self, mnemonic: str, row: str) -> str: + if args.ignore_addr_diffs: + row = self._normalize_adrp_differences(mnemonic, row) + row = self._normalize_bl(mnemonic, row) + return row + + def _normalize_bl(self, mnemonic: str, row: str) -> str: + if mnemonic != "bl": + return row + + row, _ = split_off_branch(row) + return row + + def _normalize_adrp_differences(self, mnemonic: str, row: str) -> str: + """Identifies ADRP + LDR/ADD pairs that are used to access the GOT and + suppresses any immediate differences. + + Whenever an ADRP is seen, the destination register is added to the set of registers + that are part of an ADRP + LDR/ADD pair. Registers are removed from the set as soon + as they are used for an LDR or ADD instruction which completes the pair. + + This method is somewhat crude but should manage to detect most such pairs. + """ + row_parts = row.split("\t", 1) + if mnemonic == "adrp": + self._adrp_pair_registers.add(row_parts[1].strip().split(",")[0]) + row, _ = split_off_branch(row) + elif mnemonic == "ldr": + for reg in self._adrp_pair_registers: + # ldr xxx, [reg] + # ldr xxx, [reg, ] + if f", [{reg}" in row_parts[1]: + self._adrp_pair_registers.remove(reg) + return normalize_imms(row) + elif mnemonic == "add": + for reg in self._adrp_pair_registers: + # add reg, reg, + if row_parts[1].startswith(f"{reg}, {reg}, "): + self._adrp_pair_registers.remove(reg) + return normalize_imms(row) + + return row + + +def make_difference_normalizer() -> DifferenceNormalizer: + if arch == "aarch64": + return DifferenceNormalizerAArch64() + return DifferenceNormalizer() + + +def process(lines: List[str]) -> List[Line]: + normalizer = make_difference_normalizer() skip_next = False source_lines = [] if not args.diff_obj: @@ -626,7 +893,7 @@ def process(lines): if lines and not lines[-1]: lines.pop() - output = [] + output: List[Line] = [] stop_after_delay_slot = False for row in lines: if args.diff_obj and (">:" in row or not row): @@ -648,6 +915,11 @@ def process(lines): output[-1] = output[-1]._replace(original=new_original) continue + if "R_PPC_" in row: + new_original = process_ppc_reloc(row, output[-1].original) + output[-1] = output[-1]._replace(original=new_original) + continue + m_comment = re.search(re_comment, row) comment = m_comment[0] if m_comment else None row = re.sub(re_comment, "", row) @@ -655,11 +927,18 @@ def process(lines): tabs = row.split("\t") row = "\t".join(tabs[2:]) line_num = tabs[0].strip() - row_parts = row.split("\t", 1) + + if "\t" in row: + row_parts = row.split("\t", 1) + else: + # powerpc-eabi-objdump doesn't use tabs + row_parts = [part.lstrip() for part in row.split(" ", 1)] mnemonic = row_parts[0].strip() + if mnemonic not in instructions_with_address_immediates: - row = re.sub(re_int, lambda s: hexify_int(row, s), row) + row = re.sub(re_int, lambda m: hexify_int(row, m), row) original = row + normalized_original = normalizer.normalize(mnemonic, original) if skip_next: skip_next = False row = "" @@ -688,6 +967,7 @@ def process(lines): mnemonic=mnemonic, diff_row=row, original=original, + normalized_original=normalized_original, line_num=line_num, branch_target=branch_target, source_lines=source_lines, @@ -704,16 +984,18 @@ def process(lines): return output -def format_single_line_diff(line1, line2, column_width): - return f"{ansi_ljust(line1,column_width)}{line2}" +def format_single_line_diff(line1: str, line2: str, column_width: int) -> str: + return ansi_ljust(line1, column_width) + line2 class SymbolColorer: - def __init__(self, base_index): + symbol_colors: Dict[str, str] + + def __init__(self, base_index: int) -> None: self.color_index = base_index self.symbol_colors = {} - def color_symbol(self, s, t=None): + def color_symbol(self, s: str, t: Optional[str] = None) -> str: try: color = self.symbol_colors[s] except: @@ -724,59 +1006,54 @@ class SymbolColorer: return f"{color}{t}{Fore.RESET}" -def maybe_normalize_large_imms(row): - if args.ignore_large_imms: - row = re.sub(re_large_imm, "", row) - return row - - -def normalize_imms(row): +def normalize_imms(row: str) -> str: return re.sub(re_imm, "", row) -def normalize_stack(row): +def normalize_stack(row: str) -> str: return re.sub(re_sprel, "addr(sp)", row) -def split_off_branch(line): +def split_off_branch(line: str) -> Tuple[str, str]: parts = line.split(",") if len(parts) < 2: parts = line.split(None, 1) off = len(line) - len(parts[-1]) return line[:off], line[off:] +ColorFunction = Callable[[str], str] -def color_imms(out1, out2): - g1 = [] - g2 = [] - re.sub(re_imm, lambda s: g1.append(s.group()), out1) - re.sub(re_imm, lambda s: g2.append(s.group()), out2) - if len(g1) == len(g2): - diffs = [x != y for (x, y) in zip(g1, g2)] - it = iter(diffs) +def color_fields(pat: Pattern[str], out1: str, out2: str, color1: ColorFunction, color2: Optional[ColorFunction]=None) -> Tuple[str, str]: + diffs = [of.group() != nf.group() for (of, nf) in zip(pat.finditer(out1), pat.finditer(out2))] - def maybe_color(s): - return f"{Fore.LIGHTBLUE_EX}{s}{Style.RESET_ALL}" if next(it) else s + it = iter(diffs) + def maybe_color(color: ColorFunction, s: str) -> str: + return color(s) if next(it, False) else f"{Style.RESET_ALL}{s}" + + out1 = pat.sub(lambda m: maybe_color(color1, m.group()), out1) + it = iter(diffs) + out2 = pat.sub(lambda m: maybe_color(color2 or color1, m.group()), out2) - out1 = re.sub(re_imm, lambda s: maybe_color(s.group()), out1) - it = iter(diffs) - out2 = re.sub(re_imm, lambda s: maybe_color(s.group()), out2) return out1, out2 -def color_branch_imms(br1, br2): +def color_branch_imms(br1: str, br2: str) -> Tuple[str, str]: if br1 != br2: br1 = f"{Fore.LIGHTBLUE_EX}{br1}{Style.RESET_ALL}" br2 = f"{Fore.LIGHTBLUE_EX}{br2}{Style.RESET_ALL}" return br1, br2 -def diff_sequences_difflib(seq1, seq2): +def diff_sequences_difflib( + seq1: List[str], seq2: List[str] +) -> List[Tuple[str, int, int, int, int]]: differ = difflib.SequenceMatcher(a=seq1, b=seq2, autojunk=False) return differ.get_opcodes() -def diff_sequences(seq1, seq2): +def diff_sequences( + seq1: List[str], seq2: List[str] +) -> List[Tuple[str, int, int, int, int]]: if ( args.algorithm != "levenshtein" or len(seq1) * len(seq2) > 4 * 10 ** 8 @@ -786,9 +1063,9 @@ def diff_sequences(seq1, seq2): # The Levenshtein library assumes that we compare strings, not lists. Convert. # (Per the check above we know we have fewer than 0x110000 unique elements, so chr() works.) - remapping = {} + remapping: Dict[str, str] = {} - def remap(seq): + def remap(seq: List[str]) -> str: seq = seq[:] for i in range(len(seq)): val = remapping.get(seq[i]) @@ -798,17 +1075,41 @@ def diff_sequences(seq1, seq2): seq[i] = val return "".join(seq) - seq1 = remap(seq1) - seq2 = remap(seq2) - return Levenshtein.opcodes(seq1, seq2) + rem1 = remap(seq1) + rem2 = remap(seq2) + return Levenshtein.opcodes(rem1, rem2) # type: ignore + + +def diff_lines( + lines1: List[Line], + lines2: List[Line], +) -> List[Tuple[Optional[Line], Optional[Line]]]: + ret = [] + for (tag, i1, i2, j1, j2) in diff_sequences( + [line.mnemonic for line in lines1], + [line.mnemonic for line in lines2], + ): + for line1, line2 in itertools.zip_longest(lines1[i1:i2], lines2[j1:j2]): + if tag == "replace": + if line1 is None: + tag = "insert" + elif line2 is None: + tag = "delete" + elif tag == "insert": + assert line1 is None + elif tag == "delete": + assert line2 is None + ret.append((line1, line2)) + + return ret class OutputLine: base: Optional[str] fmt2: str - key2: str + key2: Optional[str] - def __init__(self, base: Optional[str], fmt2: str, key2: str) -> None: + def __init__(self, base: Optional[str], fmt2: str, key2: Optional[str]) -> None: self.base = base self.fmt2 = fmt2 self.key2 = key2 @@ -848,141 +1149,128 @@ def do_diff(basedump: str, mydump: str) -> List[OutputLine]: btset.add(bt + ":") sc.color_symbol(bt + ":") - for (tag, i1, i2, j1, j2) in diff_sequences( - [line.mnemonic for line in lines1], [line.mnemonic for line in lines2] - ): - for line1, line2 in itertools.zip_longest(lines1[i1:i2], lines2[j1:j2]): - if tag == "replace": - if line1 is None: - tag = "insert" - elif line2 is None: - tag = "delete" - elif tag == "insert": - assert line1 is None - elif tag == "delete": - assert line2 is None + for (line1, line2) in diff_lines(lines1, lines2): + line_color1 = line_color2 = sym_color = Fore.RESET + line_prefix = " " + if line1 and line2 and line1.diff_row == line2.diff_row: + if line1.normalized_original == line2.normalized_original: + out1 = line1.original + out2 = line2.original + elif line1.diff_row == "": + out1 = f"{Style.BRIGHT}{Fore.LIGHTBLACK_EX}{line1.original}" + out2 = f"{Style.BRIGHT}{Fore.LIGHTBLACK_EX}{line2.original}" + else: + mnemonic = line1.original.split()[0] + out1, out2 = line1.original, line2.original + branch1 = branch2 = "" + if mnemonic in instructions_with_address_immediates: + out1, branch1 = split_off_branch(line1.original) + out2, branch2 = split_off_branch(line2.original) + branchless1 = out1 + branchless2 = out2 + out1, out2 = color_fields(re_imm, out1, out2, lambda s: f"{Fore.LIGHTBLUE_EX}{s}{Style.RESET_ALL}") - line_color1 = line_color2 = sym_color = Fore.RESET - line_prefix = " " - if line1 and line2 and line1.diff_row == line2.diff_row: - if maybe_normalize_large_imms( - line1.original - ) == maybe_normalize_large_imms(line2.original): - out1 = line1.original - out2 = line2.original - elif line1.diff_row == "": - out1 = f"{Style.BRIGHT}{Fore.LIGHTBLACK_EX}{line1.original}" - out2 = f"{Style.BRIGHT}{Fore.LIGHTBLACK_EX}{line2.original}" - else: - mnemonic = line1.original.split()[0] - out1, out2 = line1.original, line2.original - branch1 = branch2 = "" - if mnemonic in instructions_with_address_immediates: - out1, branch1 = split_off_branch(line1.original) - out2, branch2 = split_off_branch(line2.original) - branchless1 = out1 - branchless2 = out2 - out1, out2 = color_imms(out1, out2) + same_relative_target = False + if line1.branch_target is not None and line2.branch_target is not None: + relative_target1 = eval_line_num(line1.branch_target) - eval_line_num(line1.line_num) + relative_target2 = eval_line_num(line2.branch_target) - eval_line_num(line2.line_num) + same_relative_target = relative_target1 == relative_target2 - same_relative_target = False - if line1.branch_target is not None and line2.branch_target is not None: - relative_target1 = eval_line_num(line1.branch_target) - eval_line_num(line1.line_num) - relative_target2 = eval_line_num(line2.branch_target) - eval_line_num(line2.line_num) - same_relative_target = relative_target1 == relative_target2 + if not same_relative_target: + branch1, branch2 = color_branch_imms(branch1, branch2) + out1 += branch1 + out2 += branch2 + if normalize_imms(branchless1) == normalize_imms(branchless2): if not same_relative_target: - branch1, branch2 = color_branch_imms(branch1, branch2) - - out1 += branch1 - out2 += branch2 - if normalize_imms(branchless1) == normalize_imms(branchless2): - if not same_relative_target: - # only imms differences - sym_color = Fore.LIGHTBLUE_EX - line_prefix = "i" + # only imms differences + sym_color = Fore.LIGHTBLUE_EX + line_prefix = "i" + else: + out1, out2 = color_fields(re_sprel, out1, out2, sc3.color_symbol, sc4.color_symbol) + if normalize_stack(branchless1) == normalize_stack(branchless2): + # only stack differences (luckily stack and imm + # differences can't be combined in MIPS, so we + # don't have to think about that case) + sym_color = Fore.YELLOW + line_prefix = "s" else: - out1 = re.sub( - re_sprel, lambda s: sc3.color_symbol(s.group()), out1, - ) - out2 = re.sub( - re_sprel, lambda s: sc4.color_symbol(s.group()), out2, - ) - if normalize_stack(branchless1) == normalize_stack(branchless2): - # only stack differences (luckily stack and imm - # differences can't be combined in MIPS, so we - # don't have to think about that case) - sym_color = Fore.YELLOW - line_prefix = "s" - else: - # regs differences and maybe imms as well - out1 = re.sub( - re_reg, lambda s: sc1.color_symbol(s.group()), out1 + # regs differences and maybe imms as well + out1, out2 = color_fields(re_reg, out1, out2, sc1.color_symbol, sc2.color_symbol) + line_color1 = line_color2 = sym_color = Fore.YELLOW + line_prefix = "r" + elif line1 and line2: + line_prefix = "|" + line_color1 = Fore.LIGHTBLUE_EX + line_color2 = Fore.LIGHTBLUE_EX + sym_color = Fore.LIGHTBLUE_EX + out1 = line1.original + out2 = line2.original + elif line1: + line_prefix = "<" + line_color1 = sym_color = Fore.RED + out1 = line1.original + out2 = "" + elif line2: + line_prefix = ">" + line_color2 = sym_color = Fore.GREEN + out1 = "" + out2 = line2.original + + if args.source and line2 and line2.comment: + out2 += f" {line2.comment}" + + def format_part( + out: str, + line: Optional[Line], + line_color: str, + btset: Set[str], + sc: SymbolColorer, + ) -> Optional[str]: + if line is None: + return None + in_arrow = " " + out_arrow = "" + if args.show_branches: + if line.line_num in btset: + in_arrow = sc.color_symbol(line.line_num, "~>") + line_color + if line.branch_target is not None: + out_arrow = " " + sc.color_symbol(line.branch_target + ":", "~>") + out = pad_mnemonic(out) + return f"{line_color}{line.line_num} {in_arrow} {out}{Style.RESET_ALL}{out_arrow}" + + part1 = format_part(out1, line1, line_color1, bts1, sc5) + part2 = format_part(out2, line2, line_color2, bts2, sc6) + key2 = line2.original if line2 else None + + mid = f"{sym_color}{line_prefix}" + + if line2: + for source_line in line2.source_lines: + color = Style.DIM + # File names and function names + if source_line and source_line[0] != "│": + color += Style.BRIGHT + # Function names + if source_line.endswith("():"): + # Underline. Colorama does not provide this feature, unfortunately. + color += "\u001b[4m" + try: + source_line = cxxfilt.demangle( + source_line[:-3], external_only=False ) - out2 = re.sub( - re_reg, lambda s: sc2.color_symbol(s.group()), out2 - ) - line_color1 = line_color2 = sym_color = Fore.YELLOW - line_prefix = "r" - elif line1 and line2: - line_prefix = "|" - line_color1 = Fore.LIGHTBLUE_EX - line_color2 = Fore.LIGHTBLUE_EX - sym_color = Fore.LIGHTBLUE_EX - out1 = line1.original - out2 = line2.original - elif line1: - line_prefix = "<" - line_color1 = sym_color = Fore.RED - out1 = line1.original - out2 = "" - elif line2: - line_prefix = ">" - line_color2 = sym_color = Fore.GREEN - out1 = "" - out2 = line2.original + except: + pass + output.append( + OutputLine( + None, + f" {color}{source_line}{Style.RESET_ALL}", + source_line, + ) + ) - if args.source and line2 and line2.comment: - out2 += f" {line2.comment}" - - def format_part(out: str, line: Optional[Line], line_color: str, btset: Set[str], sc: SymbolColorer) -> Optional[str]: - if line is None: - return None - in_arrow = " " - out_arrow = "" - if args.show_branches: - if line.line_num in btset: - in_arrow = sc.color_symbol(line.line_num, "~>") + line_color - if line.branch_target is not None: - out_arrow = " " + sc.color_symbol(line.branch_target + ":", "~>") - out = pad_mnemonic(out) - return f"{line_color}{line.line_num} {in_arrow} {out}{Style.RESET_ALL}{out_arrow}" - - part1 = format_part(out1, line1, line_color1, bts1, sc5) - part2 = format_part(out2, line2, line_color2, bts2, sc6) - key2 = line2.original if line2 else "" - - mid = f"{sym_color}{line_prefix}" - - if line2: - for source_line in line2.source_lines: - color = Style.DIM - # File names and function names - if source_line and source_line[0] != "|": - color += Style.BRIGHT - # Function names - if source_line.endswith("():"): - # Underline. Colorama does not provide this feature, unfortunately. - color += "\u001b[4m" - try: - source_line = cxxfilt.demangle( - source_line[:-3], external_only=False - ) - except: - pass - output.append(OutputLine(None, f" {color}{source_line}{Style.RESET_ALL}", source_line)) - - fmt2 = mid + " " + (part2 or "") - output.append(OutputLine(part1, fmt2, key2)) + fmt2 = mid + " " + (part2 or "") + output.append(OutputLine(part1, fmt2, key2)) return output @@ -1001,12 +1289,14 @@ def chunk_diff(diff: List[OutputLine]) -> List[Union[List[OutputLine], OutputLin return chunks -def format_diff(old_diff: List[OutputLine], new_diff: List[OutputLine]) -> Tuple[str, List[str]]: +def format_diff( + old_diff: List[OutputLine], new_diff: List[OutputLine] +) -> Tuple[str, List[str]]: old_chunks = chunk_diff(old_diff) new_chunks = chunk_diff(new_diff) output: List[Tuple[str, OutputLine, OutputLine]] = [] assert len(old_chunks) == len(new_chunks), "same target" - empty = OutputLine("", "", "") + empty = OutputLine("", "", None) for old_chunk, new_chunk in zip(old_chunks, new_chunks): if isinstance(old_chunk, list): assert isinstance(new_chunk, list) @@ -1019,18 +1309,19 @@ def format_diff(old_diff: List[OutputLine], new_diff: List[OutputLine]) -> Tuple if tag in ["equal", "replace"]: for i, j in zip(range(i1, i2), range(j1, j2)): output.append(("", old_chunk[i], new_chunk[j])) - elif tag == "insert": - for j in range(j1, j2): + if tag in ["insert", "replace"]: + for j in range(j1 + i2 - i1, j2): output.append(("", empty, new_chunk[j])) - else: - for i in range(i1, i2): + if tag in ["delete", "replace"]: + for i in range(i1 + j2 - j1, i2): output.append(("", old_chunk[i], empty)) else: assert isinstance(new_chunk, OutputLine) + assert new_chunk.base # old_chunk.base and new_chunk.base have the same text since # both diffs are based on the same target, but they might # differ in color. Use the new version. - output.append((new_chunk.base or "", old_chunk, new_chunk)) + output.append((new_chunk.base, old_chunk, new_chunk)) # TODO: status line, with e.g. approximate permuter score? width = args.column_width @@ -1047,29 +1338,35 @@ def format_diff(old_diff: List[OutputLine], new_diff: List[OutputLine]) -> Tuple diff_lines = [ ansi_ljust(base, width) + new.fmt2 for (base, old, new) in output - if base or new.key2 + if base or new.key2 is not None ] return header_line, diff_lines -def debounced_fs_watch(targets, outq, debounce_delay): +def debounced_fs_watch( + targets: List[str], + outq: "queue.Queue[Optional[float]]", + debounce_delay: float, +) -> None: import watchdog.events # type: ignore import watchdog.observers # type: ignore - class WatchEventHandler(watchdog.events.FileSystemEventHandler): - def __init__(self, queue, file_targets): + class WatchEventHandler(watchdog.events.FileSystemEventHandler): # type: ignore + def __init__( + self, queue: "queue.Queue[float]", file_targets: List[str] + ) -> None: self.queue = queue self.file_targets = file_targets - def on_modified(self, ev): + def on_modified(self, ev: object) -> None: if isinstance(ev, watchdog.events.FileModifiedEvent): self.changed(ev.src_path) - def on_moved(self, ev): + def on_moved(self, ev: object) -> None: if isinstance(ev, watchdog.events.FileMovedEvent): self.changed(ev.dest_path) - def should_notify(self, path): + def should_notify(self, path: str) -> bool: for target in self.file_targets: if path == target: return True @@ -1079,13 +1376,13 @@ def debounced_fs_watch(targets, outq, debounce_delay): return True return False - def changed(self, path): + def changed(self, path: str) -> None: if self.should_notify(path): self.queue.put(time.time()) - def debounce_thread(): - listenq = queue.Queue() - file_targets = [] + def debounce_thread() -> NoReturn: + listenq: "queue.Queue[float]" = queue.Queue() + file_targets: List[str] = [] event_handler = WatchEventHandler(listenq, file_targets) observer = watchdog.observers.Observer() observed = set() @@ -1121,19 +1418,29 @@ def debounced_fs_watch(targets, outq, debounce_delay): class Display: - def __init__(self, basedump, mydump): + basedump: str + mydump: str + emsg: Optional[str] + last_diff_output: Optional[List[OutputLine]] + pending_update: Optional[Tuple[str, bool]] + ready_queue: "queue.Queue[None]" + watch_queue: "queue.Queue[Optional[float]]" + less_proc: "Optional[subprocess.Popen[bytes]]" + + def __init__(self, basedump: str, mydump: str) -> None: self.basedump = basedump self.mydump = mydump self.emsg = None self.last_diff_output = None - def run_less(self): + def run_less(self) -> "Tuple[subprocess.Popen[bytes], subprocess.Popen[bytes]]": if self.emsg is not None: output = self.emsg else: diff_output = do_diff(self.basedump, self.mydump) last_diff_output = self.last_diff_output or diff_output - self.last_diff_output = diff_output + if args.threeway != "base" or not self.last_diff_output: + self.last_diff_output = diff_output header, diff_lines = format_diff(last_diff_output, diff_output) header_lines = [header] if header else [] output = "\n".join(header_lines + diff_lines[args.skip_lines :]) @@ -1146,17 +1453,19 @@ class Display: BUFFER_CMD, stdin=subprocess.PIPE, stdout=subprocess.PIPE ) less_proc = subprocess.Popen(LESS_CMD, stdin=buffer_proc.stdout) + assert buffer_proc.stdin + assert buffer_proc.stdout buffer_proc.stdin.write(output.encode()) buffer_proc.stdin.close() buffer_proc.stdout.close() return (buffer_proc, less_proc) - def run_sync(self): + def run_sync(self) -> None: proca, procb = self.run_less() procb.wait() proca.wait() - def run_async(self, watch_queue): + def run_async(self, watch_queue: "queue.Queue[Optional[float]]") -> None: self.watch_queue = watch_queue self.ready_queue = queue.Queue() self.pending_update = None @@ -1164,10 +1473,10 @@ class Display: dthread.start() self.ready_queue.get() - def display_thread(self): + def display_thread(self) -> None: proca, procb = self.run_less() self.less_proc = procb - self.ready_queue.put(0) + self.ready_queue.put(None) while True: ret = procb.wait() proca.wait() @@ -1186,19 +1495,19 @@ class Display: self.emsg = msg proca, procb = self.run_less() self.less_proc = procb - self.ready_queue.put(0) + self.ready_queue.put(None) else: # terminated by user, or killed self.watch_queue.put(None) - self.ready_queue.put(0) + self.ready_queue.put(None) break - def progress(self, msg): + def progress(self, msg: str) -> None: # Write message to top-left corner sys.stdout.write("\x1b7\x1b[1;1f{}\x1b8".format(msg + " ")) sys.stdout.flush() - def update(self, text, error): + def update(self, text: str, error: bool) -> None: if not error and not self.emsg and text == self.mydump: self.progress("Unchanged. ") return @@ -1208,14 +1517,14 @@ class Display: self.less_proc.kill() self.ready_queue.get() - def terminate(self): + def terminate(self) -> None: if not self.less_proc: return self.less_proc.kill() self.ready_queue.get() -def main(): +def main() -> None: if args.diff_elf_symbol: make_target, basecmd, mycmd = dump_elf() elif args.diff_obj: @@ -1245,23 +1554,27 @@ def main(): else: if not args.make: yn = input( - "Warning: watch-mode (-w) enabled without auto-make (-m). You will have to run make manually. Ok? (Y/n) " + "Warning: watch-mode (-w) enabled without auto-make (-m). " + "You will have to run make manually. Ok? (Y/n) " ) if yn.lower() == "n": return if args.make: watch_sources = None - if hasattr(diff_settings, "watch_sources_for_target"): - watch_sources = diff_settings.watch_sources_for_target(make_target) + watch_sources_for_target_fn = getattr( + diff_settings, "watch_sources_for_target", None + ) + if watch_sources_for_target_fn: + watch_sources = watch_sources_for_target_fn(make_target) watch_sources = watch_sources or source_directories if not watch_sources: fail("Missing source_directories config, don't know what to watch.") else: watch_sources = [make_target] - q = queue.Queue() + q: "queue.Queue[Optional[float]]" = queue.Queue() debounced_fs_watch(watch_sources, q, DEBOUNCE_DELAY) display.run_async(q) - last_build = 0 + last_build = 0.0 try: while True: t = q.get() @@ -1272,7 +1585,7 @@ def main(): last_build = time.time() if args.make: display.progress("Building...") - ret = run_make(make_target, capture_output=True) + ret = run_make_capture_output(make_target) if ret.returncode != 0: display.update( ret.stderr.decode("utf-8-sig", "replace")