diff --git a/diff.py b/diff.py index fbc06cb2f6..0f22a38c77 100755 --- a/diff.py +++ b/diff.py @@ -40,7 +40,7 @@ if __name__ == "__main__": sys.path.pop(0) try: - import argcomplete # type: ignore + import argcomplete except ModuleNotFoundError: argcomplete = None @@ -104,31 +104,51 @@ if __name__ == "__main__": "-o", dest="diff_obj", action="store_true", - help="Diff .o files rather than a whole binary. This makes it possible to " - "see symbol names. (Recommended)", + help="""Diff .o files rather than a whole binary. This makes it possible to + see symbol names. (Recommended)""", ) parser.add_argument( "-e", "--elf", dest="diff_elf_symbol", metavar="SYMBOL", - help="Diff a given function in two ELFs, one being stripped and the other " - "one non-stripped. Requires objdump from binutils 2.33+.", + help="""Diff a given function in two ELFs, one being stripped and the other + one non-stripped. Requires objdump from binutils 2.33+.""", ) parser.add_argument( - "--source", "-c", + "--source", + dest="source", action="store_true", help="Show source code (if possible). Only works with -o or -e.", ) parser.add_argument( - "--source-old-binutils", "-C", + "--source-old-binutils", + dest="source_old_binutils", action="store_true", - help="Tweak --source handling to make it work with binutils < 2.33. Implies --source.", + help="""Tweak --source handling to make it work with binutils < 2.33. + Implies --source.""", + ) + parser.add_argument( + "-L", + "--line-numbers", + dest="show_line_numbers", + action="store_const", + const=True, + help="""Show source line numbers in output, when available. May be enabled by + default depending on diff_settings.py.""", + ) + parser.add_argument( + "--no-line-numbers", + dest="show_line_numbers", + action="store_const", + const=False, + help="Hide source line numbers in output.", ) parser.add_argument( "--inlines", + dest="inlines", action="store_true", help="Show inline function calls (if possible). Only works with -o or -e.", ) @@ -155,17 +175,18 @@ if __name__ == "__main__": "-l", "--skip-lines", dest="skip_lines", + metavar="LINES", type=int, default=0, - metavar="LINES", - help="Skip the first N lines of output.", + help="Skip the first LINES lines of output.", ) parser.add_argument( "-s", "--stop-jr-ra", dest="stop_jrra", action="store_true", - help="Stop disassembling at the first 'jr ra'. Some functions have multiple return points, so use with care!", + help="""Stop disassembling at the first 'jr ra'. Some functions have + multiple return points, so use with care!""", ) parser.add_argument( "-i", @@ -192,20 +213,21 @@ if __name__ == "__main__": "-S", "--base-shift", dest="base_shift", + metavar="N", type=str, default="0", - help="Diff position X in our img against position X + shift in the base img. " - 'Arithmetic is allowed, so e.g. |-S "0x1234 - 0x4321"| is a reasonable ' - "flag to pass if it is known that position 0x1234 in the base img syncs " - "up with position 0x4321 in our img. Not supported together with -o.", + help="""Diff position N in our img against position N + shift in the base img. + Arithmetic is allowed, so e.g. |-S "0x1234 - 0x4321"| is a reasonable + flag to pass if it is known that position 0x1234 in the base img syncs + up with position 0x4321 in our img. Not supported together with -o.""", ) parser.add_argument( "-w", "--watch", dest="watch", action="store_true", - help="Automatically update when source/object files change. " - "Recommended in combination with -m.", + help="""Automatically update when source/object files change. + Recommended in combination with -m.""", ) parser.add_argument( "-3", @@ -213,8 +235,8 @@ if __name__ == "__main__": dest="threeway", action="store_const", const="prev", - help="Show a three-way diff between target asm, current asm, and asm " - "prior to -w rebuild. Requires -w.", + help="""Show a three-way diff between target asm, current asm, and asm + prior to -w rebuild. Requires -w.""", ) parser.add_argument( "-b", @@ -222,12 +244,13 @@ if __name__ == "__main__": dest="threeway", action="store_const", const="base", - help="Show a three-way diff between target asm, current asm, and asm " - "when diff.py was started. Requires -w.", + help="""Show a three-way diff between target asm, current asm, and asm + when diff.py was started. Requires -w.""", ) parser.add_argument( "--width", dest="column_width", + metavar="COLS", type=int, default=50, help="Sets the width of the left and right view column.", @@ -237,12 +260,13 @@ if __name__ == "__main__": dest="algorithm", default="levenshtein", choices=["levenshtein", "difflib"], - help="Diff algorithm to use. Levenshtein gives the minimum diff, while difflib " - "aims for long sections of equal opcodes. Defaults to %(default)s.", + help="""Diff algorithm to use. Levenshtein gives the minimum diff, while difflib + aims for long sections of equal opcodes. Defaults to %(default)s.""", ) parser.add_argument( "--max-size", "--max-lines", + metavar="LINES", dest="max_lines", type=int, default=1024, @@ -252,14 +276,32 @@ if __name__ == "__main__": "--no-pager", dest="no_pager", action="store_true", - help="Disable the pager; write output directly to stdout, then exit. " - "Incompatible with --watch.", + help="""Disable the pager; write output directly to stdout, then exit. + Incompatible with --watch.""", ) parser.add_argument( "--format", - choices=("color", "plain", "html"), + choices=("color", "plain", "html", "json"), default="color", - help="Output format, default is color. --format=html implies --no-pager.", + help="Output format, default is color. --format=html or json implies --no-pager.", + ) + parser.add_argument( + "-U", + "--compress-matching", + metavar="N", + dest="compress_matching", + type=int, + help="""Compress streaks of matching lines, leaving N lines of context + around non-matching parts.""", + ) + parser.add_argument( + "-V", + "--compress-sameinstr", + metavar="N", + dest="compress_sameinstr", + type=int, + help="""Compress streaks of lines with same instructions (but possibly + different regalloc), leaving N lines of context around other parts.""", ) # Project-specific flags, e.g. different versions/make arguments. @@ -276,29 +318,32 @@ if __name__ == "__main__": import abc import ast -from dataclasses import dataclass, field, replace +from collections import Counter, defaultdict +from dataclasses import asdict, dataclass, field, replace import difflib import enum import html import itertools +import json import os import queue import re import string +import struct import subprocess import threading import time +import traceback MISSING_PREREQUISITES = ( "Missing prerequisite python module {}. " - "Run `python3 -m pip install --user colorama ansiwrap watchdog python-Levenshtein cxxfilt` to install prerequisites (cxxfilt only needed with --source)." + "Run `python3 -m pip install --user colorama watchdog python-Levenshtein cxxfilt` to install prerequisites (cxxfilt only needed with --source)." ) try: - from colorama import Fore, Style # type: ignore - import ansiwrap # type: ignore - import watchdog # type: ignore + from colorama import Back, Fore, Style + import watchdog except ModuleNotFoundError as e: fail(MISSING_PREREQUISITES.format(e.name)) @@ -317,6 +362,13 @@ class ProjectSettings: mapfile: Optional[str] source_directories: Optional[List[str]] source_extensions: List[str] + show_line_numbers_default: bool + + +@dataclass +class Compress: + context: int + same_instr: bool @dataclass @@ -337,12 +389,22 @@ class Config: threeway: Optional[str] base_shift: int skip_lines: int + compress: Optional[Compress] show_branches: bool + show_line_numbers: bool stop_jrra: bool ignore_large_imms: bool ignore_addr_diffs: bool algorithm: str + # Score options + score_stack_differences = True + penalty_stackdiff = 1 + penalty_regalloc = 5 + penalty_reordering = 60 + penalty_insertion = 100 + penalty_deletion = 100 + def create_project_settings(settings: Dict[str, Any]) -> ProjectSettings: return ProjectSettings( @@ -360,6 +422,7 @@ def create_project_settings(settings: Dict[str, Any]) -> ProjectSettings: objdump_executable=get_objdump_executable(settings.get("objdump_executable")), map_format=settings.get("map_format", "gnu"), mw_build_dir=settings.get("mw_build_dir", "build/"), + show_line_numbers_default=settings.get("show_line_numbers_default", True), ) @@ -371,9 +434,25 @@ def create_config(args: argparse.Namespace, project: ProjectSettings) -> Config: formatter = AnsiFormatter(column_width=args.column_width) elif args.format == "html": formatter = HtmlFormatter() + elif args.format == "json": + formatter = JsonFormatter(arch_str=project.arch_str) else: raise ValueError(f"Unsupported --format: {args.format}") + compress = None + if args.compress_matching is not None: + compress = Compress(args.compress_matching, False) + if args.compress_sameinstr is not None: + if compress is not None: + raise ValueError( + "Cannot pass both --compress-matching and --compress-sameinstr" + ) + compress = Compress(args.compress_sameinstr, True) + + show_line_numbers = args.show_line_numbers + if show_line_numbers is None: + show_line_numbers = project.show_line_numbers_default + return Config( arch=get_arch(project.arch_str), # Build/objdump options @@ -391,7 +470,9 @@ def create_config(args: argparse.Namespace, project: ProjectSettings) -> Config: args.base_shift, "Failed to parse --base-shift (-S) argument as an integer." ), skip_lines=args.skip_lines, + compress=compress, show_branches=args.show_branches, + show_line_numbers=show_line_numbers, stop_jrra=args.stop_jrra, ignore_large_imms=args.ignore_large_imms, ignore_addr_diffs=args.ignore_addr_diffs, @@ -457,6 +538,7 @@ class BasicFormat(enum.Enum): DIFF_REMOVE = enum.auto() SOURCE_FILENAME = enum.auto() SOURCE_FUNCTION = enum.auto() + SOURCE_LINE_NUM = enum.auto() SOURCE_OTHER = enum.auto() @@ -474,14 +556,8 @@ FormatFunction = Callable[[str], Format] class Text: segments: List[Tuple[str, Format]] - def __init__( - self, line: Optional[str] = None, f: Format = BasicFormat.NONE - ) -> None: - self.segments = [] - if line is not None: - self.segments.append((line, f)) - elif f is not BasicFormat.NONE: - raise ValueError("Text constructor provided `f`, but no line to format") + def __init__(self, line: str = "", f: Format = BasicFormat.NONE) -> None: + self.segments = [(line, f)] if line else [] def reformat(self, f: Format) -> "Text": return Text(self.plain(), f) @@ -492,6 +568,9 @@ class Text: def __repr__(self) -> str: return f"" + def __bool__(self) -> bool: + return any(s for s, f in self.segments) + def __str__(self) -> str: # Use Formatter.apply(...) instead return NotImplemented @@ -503,15 +582,25 @@ class Text: if isinstance(other, str): other = Text(other) result = Text() - result.segments = self.segments + other.segments + # If two adjacent segments have the same format, merge their lines + if ( + self.segments + and other.segments + and self.segments[-1][1] == other.segments[0][1] + ): + result.segments = ( + self.segments[:-1] + + [(self.segments[-1][0] + other.segments[0][0], self.segments[-1][1])] + + other.segments[1:] + ) + else: + result.segments = self.segments + other.segments return result def __radd__(self, other: Union["Text", str]) -> "Text": if isinstance(other, str): other = Text(other) - result = Text() - result.segments = other.segments + self.segments - return result + return other + self def finditer(self, pat: Pattern[str]) -> Iterator[Match[str]]: """Replacement for `pat.finditer(text)` that operates on the inner text, @@ -528,12 +617,25 @@ class Text: start, end = match.start(), match.end() assert i <= start <= end <= len(chunk) sub = sub_fn(match) - result.segments.append((chunk[i:start], f)) + if i != start: + result.segments.append((chunk[i:start], f)) result.segments.extend(sub.segments) i = end - result.segments.append((chunk[i:], f)) + if chunk[i:]: + result.segments.append((chunk[i:], f)) return result + def ljust(self, column_width: int) -> "Text": + length = sum(len(x) for x, _ in self.segments) + return self + " " * max(column_width - length, 0) + + +@dataclass +class TableMetadata: + headers: Tuple[Text, ...] + current_score: int + previous_score: Optional[int] + class Formatter(abc.ABC): @abc.abstractmethod @@ -542,15 +644,17 @@ class Formatter(abc.ABC): ... @abc.abstractmethod - def table( - self, header: Optional[Tuple[str, ...]], lines: List[Tuple[str, ...]] - ) -> str: - """Format a multi-column table with an optional `header`""" + def table(self, meta: TableMetadata, lines: List[Tuple["OutputLine", ...]]) -> str: + """Format a multi-column table with metadata""" ... def apply(self, text: Text) -> str: return "".join(self.apply_format(chunk, f) for chunk, f in text.segments) + @staticmethod + def outputline_texts(lines: Tuple["OutputLine", ...]) -> Tuple[Text, ...]: + return tuple([lines[0].base or Text()] + [line.fmt2 for line in lines[1:]]) + @dataclass class PlainFormatter(Formatter): @@ -559,18 +663,21 @@ class PlainFormatter(Formatter): def apply_format(self, chunk: str, f: Format) -> str: return chunk - def table( - self, header: Optional[Tuple[str, ...]], lines: List[Tuple[str, ...]] - ) -> str: - if header: - lines = [header] + lines + def table(self, meta: TableMetadata, lines: List[Tuple["OutputLine", ...]]) -> str: + rows = [meta.headers] + [self.outputline_texts(ls) for ls in lines] return "\n".join( - "".join(x.ljust(self.column_width) for x in line) for line in lines + "".join(self.apply(x.ljust(self.column_width)) for x in row) for row in rows ) @dataclass class AnsiFormatter(Formatter): + # Additional ansi escape codes not in colorama. See: + # https://en.wikipedia.org/wiki/ANSI_escape_code#SGR_(Select_Graphic_Rendition)_parameters + STYLE_UNDERLINE = "\x1b[4m" + STYLE_NO_UNDERLINE = "\x1b[24m" + STYLE_INVERT = "\x1b[7m" + BASIC_ANSI_CODES = { BasicFormat.NONE: "", BasicFormat.IMMEDIATE: Fore.LIGHTBLUE_EX, @@ -581,11 +688,18 @@ class AnsiFormatter(Formatter): BasicFormat.DIFF_ADD: Fore.GREEN, BasicFormat.DIFF_REMOVE: Fore.RED, BasicFormat.SOURCE_FILENAME: Style.DIM + Style.BRIGHT, - # Underline (not in colorama) + bright + dim - BasicFormat.SOURCE_FUNCTION: Style.DIM + Style.BRIGHT + "\u001b[4m", + BasicFormat.SOURCE_FUNCTION: Style.DIM + Style.BRIGHT + STYLE_UNDERLINE, + BasicFormat.SOURCE_LINE_NUM: Fore.LIGHTBLACK_EX, BasicFormat.SOURCE_OTHER: Style.DIM, } + BASIC_ANSI_CODES_UNDO = { + BasicFormat.NONE: "", + BasicFormat.SOURCE_FILENAME: Style.NORMAL, + BasicFormat.SOURCE_FUNCTION: Style.NORMAL + STYLE_NO_UNDERLINE, + BasicFormat.SOURCE_OTHER: Style.NORMAL, + } + ROTATION_ANSI_COLORS = [ Fore.MAGENTA, Fore.CYAN, @@ -603,30 +717,30 @@ class AnsiFormatter(Formatter): def apply_format(self, chunk: str, f: Format) -> str: if f == BasicFormat.NONE: return chunk + undo_ansi_code = Fore.RESET if isinstance(f, BasicFormat): ansi_code = self.BASIC_ANSI_CODES[f] + undo_ansi_code = self.BASIC_ANSI_CODES_UNDO.get(f, undo_ansi_code) elif isinstance(f, RotationFormat): ansi_code = self.ROTATION_ANSI_COLORS[ f.index % len(self.ROTATION_ANSI_COLORS) ] else: static_assert_unreachable(f) - return f"{ansi_code}{chunk}{Style.RESET_ALL}" + return f"{ansi_code}{chunk}{undo_ansi_code}" - def table( - self, header: Optional[Tuple[str, ...]], lines: List[Tuple[str, ...]] - ) -> str: - if header: - lines = [header] + lines - return "\n".join("".join(self.ansi_ljust(x) for x in line) for line in lines) - - def ansi_ljust(self, s: str) -> str: - """Like s.ljust(width), but accounting for ANSI colors.""" - needed: int = self.column_width - ansiwrap.ansilen(s) - if needed > 0: - return s + " " * needed - else: - return s + def table(self, meta: TableMetadata, lines: List[Tuple["OutputLine", ...]]) -> str: + rows = [(meta.headers, False)] + [ + (self.outputline_texts(line), line[1].is_data_ref) for line in lines + ] + return "\n".join( + "".join( + (self.STYLE_INVERT if is_data_ref else "") + + self.apply(x.ljust(self.column_width)) + for x in row + ) + for (row, is_data_ref) in rows + ) @dataclass @@ -648,28 +762,107 @@ class HtmlFormatter(Formatter): static_assert_unreachable(f) return f"{chunk}" - def table( - self, header: Optional[Tuple[str, ...]], lines: List[Tuple[str, ...]] - ) -> str: - def table_row(line: Tuple[str, ...], cell_el: str) -> str: - output_row = " " + def table(self, meta: TableMetadata, lines: List[Tuple["OutputLine", ...]]) -> str: + def table_row(line: Tuple[Text, ...], is_data_ref: bool, cell_el: str) -> str: + tr_attrs = " class='data-ref'" if is_data_ref else "" + output_row = f" " for cell in line: - output_row += f"<{cell_el}>{cell}" + cell_html = self.apply(cell) + output_row += f"<{cell_el}>{cell_html}" output_row += "\n" return output_row output = "\n" - if header: - output += " \n" - output += table_row(header, "th") - output += " \n" + output += " \n" + output += table_row(meta.headers, False, "th") + output += " \n" output += " \n" - output += "".join(table_row(line, "td") for line in lines) + output += "".join( + table_row(self.outputline_texts(line), line[1].is_data_ref, "td") + for line in lines + ) output += " \n" output += "
\n" return output +@dataclass +class JsonFormatter(Formatter): + arch_str: str + + def apply_format(self, chunk: str, f: Format) -> str: + # This method is unused by this formatter + return NotImplemented + + def table(self, meta: TableMetadata, rows: List[Tuple["OutputLine", ...]]) -> str: + def serialize_format(s: str, f: Format) -> Dict[str, Any]: + if f == BasicFormat.NONE: + return {"text": s} + elif isinstance(f, BasicFormat): + return {"text": s, "format": f.name.lower()} + elif isinstance(f, RotationFormat): + attrs = asdict(f) + attrs.update( + { + "text": s, + "format": "rotation", + } + ) + return attrs + else: + static_assert_unreachable(f) + + def serialize(text: Optional[Text]) -> List[Dict[str, Any]]: + if text is None: + return [] + return [serialize_format(s, f) for s, f in text.segments] + + is_threeway = len(meta.headers) == 3 + + output: Dict[str, Any] = {} + output["arch_str"] = self.arch_str + output["header"] = { + name: serialize(h) + for h, name in zip(meta.headers, ("base", "current", "previous")) + } + output["current_score"] = meta.current_score + if meta.previous_score is not None: + output["previous_score"] = meta.previous_score + output_rows: List[Dict[str, Any]] = [] + for row in rows: + output_row: Dict[str, Any] = {} + output_row["key"] = row[0].key2 + output_row["is_data_ref"] = row[1].is_data_ref + iters = [ + ("base", row[0].base, row[0].line1), + ("current", row[1].fmt2, row[1].line2), + ] + if is_threeway: + iters.append(("previous", row[2].fmt2, row[2].line2)) + if all(line is None for _, _, line in iters): + # Skip rows that were only for displaying source code + continue + for column_name, text, line in iters: + column: Dict[str, Any] = {} + column["text"] = serialize(text) + if line: + if line.line_num is not None: + column["line"] = line.line_num + if line.branch_target is not None: + column["branch"] = line.branch_target + if line.source_lines: + column["src"] = line.source_lines + if line.comment is not None: + column["src_comment"] = line.comment + if line.source_line_num is not None: + column["src_line"] = line.source_line_num + if line or column["text"]: + output_row[column_name] = column + output_rows.append(output_row) + output["rows"] = output_rows + return json.dumps(output) + + def format_fields( pat: Pattern[str], out1: Text, @@ -732,8 +925,11 @@ def eval_int(expr: str, emsg: str) -> int: return ret -def eval_line_num(expr: str) -> int: - return int(expr.strip().replace(":", ""), 16) +def eval_line_num(expr: str) -> Optional[int]: + expr = expr.strip().replace(":", "") + if expr == "": + return None + return int(expr, 16) def run_make(target: str, project: ProjectSettings) -> None: @@ -750,34 +946,35 @@ def run_make_capture_output( ) -def restrict_to_function(dump: str, fn_name: str, config: Config) -> str: - out: List[str] = [] - search = f"<{fn_name}>:" - found = False - for line in dump.split("\n"): - if found: - if len(out) >= config.max_function_size_lines: - break - out.append(line) - elif search in line: - found = True - return "\n".join(out) +def restrict_to_function(dump: str, fn_name: str) -> str: + try: + ind = dump.index("\n", dump.index(f"<{fn_name}>:")) + return dump[ind + 1 :] + except ValueError: + return "" + + +def serialize_data_references(references: List[Tuple[int, int, str]]) -> str: + return "".join( + f"DATAREF {text_offset} {from_offset} {from_section}\n" + for (text_offset, from_offset, from_section) in references + ) def maybe_get_objdump_source_flags(config: Config) -> List[str]: - if not config.source: - return [] + flags = [] - flags = [ - "--source", - "-l", - ] + if config.show_line_numbers or config.source: + flags.append("--line-numbers") - if not config.source_old_binutils: - flags.append("--source-comment=│ ") + if config.source: + flags.append("--source") - if config.inlines: - flags.append("--inlines") + if not config.source_old_binutils: + flags.append("--source-comment=│ ") + + if config.inlines: + flags.append("--inlines") return flags @@ -800,7 +997,16 @@ def run_objdump(cmd: ObjdumpCommand, config: Config, project: ProjectSettings) - raise e if restrict is not None: - return restrict_to_function(out, restrict, config) + out = restrict_to_function(out, restrict) + + if config.diff_obj: + with open(target, "rb") as f: + data = f.read() + out = serialize_data_references(parse_elf_data_references(data)) + out + else: + for i in range(7): + out = out[out.find("\n") + 1 :] + out = out.rstrip("\n") return out @@ -838,8 +1044,6 @@ def search_map_file( cands.append((cur_objfile, ram + ram_to_rom)) last_line = line except Exception as e: - import traceback - traceback.print_exc() fail(f"Internal error while parsing map file") @@ -889,6 +1093,114 @@ def search_map_file( return None, None +def parse_elf_data_references(data: bytes) -> List[Tuple[int, int, str]]: + e_ident = data[:16] + if e_ident[:4] != b"\x7FELF": + return [] + + SHT_SYMTAB = 2 + SHT_REL = 9 + SHT_RELA = 4 + + is_32bit = e_ident[4] == 1 + is_little_endian = e_ident[5] == 1 + str_end = "<" if is_little_endian else ">" + str_off = "I" if is_32bit else "Q" + sym_size = {"B": 1, "H": 2, "I": 4, "Q": 8} + + def read(spec: str, offset: int) -> Tuple[int, ...]: + spec = spec.replace("P", str_off) + size = struct.calcsize(spec) + return struct.unpack(str_end + spec, data[offset : offset + size]) + + ( + e_type, + e_machine, + e_version, + e_entry, + e_phoff, + e_shoff, + e_flags, + e_ehsize, + e_phentsize, + e_phnum, + e_shentsize, + e_shnum, + e_shstrndx, + ) = read("HHIPPPIHHHHHH", 16) + if e_type != 1: # relocatable + return [] + assert e_shoff != 0 + assert e_shnum != 0 # don't support > 0xFF00 sections + assert e_shstrndx != 0 + + @dataclass + class Section: + sh_name: int + sh_type: int + sh_flags: int + sh_addr: int + sh_offset: int + sh_size: int + sh_link: int + sh_info: int + sh_addralign: int + sh_entsize: int + + sections = [ + Section(*read("IIPPPPIIPP", e_shoff + i * e_shentsize)) for i in range(e_shnum) + ] + shstr = sections[e_shstrndx] + sec_name_offs = [shstr.sh_offset + s.sh_name for s in sections] + sec_names = [data[offset : data.index(b"\0", offset)] for offset in sec_name_offs] + + symtab_sections = [i for i in range(e_shnum) if sections[i].sh_type == SHT_SYMTAB] + assert len(symtab_sections) == 1 + symtab = sections[symtab_sections[0]] + + text_sections = [i for i in range(e_shnum) if sec_names[i] == b".text"] + assert len(text_sections) == 1 + text_section = text_sections[0] + + ret: List[Tuple[int, int, str]] = [] + for s in sections: + if s.sh_type == SHT_REL or s.sh_type == SHT_RELA: + if s.sh_info == text_section: + # Skip .text -> .text references + continue + sec_name = sec_names[s.sh_info].decode("latin1") + sec_base = sections[s.sh_info].sh_offset + for i in range(0, s.sh_size, s.sh_entsize): + if s.sh_type == SHT_REL: + r_offset, r_info = read("PP", s.sh_offset + i) + else: + r_offset, r_info, r_addend = read("PPP", s.sh_offset + i) + + if is_32bit: + r_sym = r_info >> 8 + r_type = r_info & 0xFF + sym_offset = symtab.sh_offset + symtab.sh_entsize * r_sym + st_name, st_value, st_size, st_info, st_other, st_shndx = read( + "IIIBBH", sym_offset + ) + else: + r_sym = r_info >> 32 + r_type = r_info & 0xFFFFFFFF + sym_offset = symtab.sh_offset + symtab.sh_entsize * r_sym + st_name, st_info, st_other, st_shndx, st_value, st_size = read( + "IBBHQQ", sym_offset + ) + if st_shndx == text_section: + if s.sh_type == SHT_REL: + if e_machine == 8 and r_type == 2: # R_MIPS_32 + (r_addend,) = read("I", sec_base + r_offset) + else: + continue + text_offset = (st_value + r_addend) & 0xFFFFFFFF + ret.append((text_offset, r_offset, sec_name)) + return ret + + def dump_elf( start: str, end: Optional[str], @@ -953,7 +1265,7 @@ def dump_objfile( if not os.path.isfile(refobjfile): fail(f'Please ensure an OK .o file exists at "{refobjfile}".') - objdump_flags = ["-drz"] + objdump_flags = ["-drz", "-j", ".text"] return ( objfile, (objdump_flags, refobjfile, start), @@ -996,8 +1308,9 @@ class DifferenceNormalizer: def normalize(self, mnemonic: str, row: str) -> str: """This should be called exactly once for each line.""" + arch = self.config.arch row = self._normalize_arch_specific(mnemonic, row) - if self.config.ignore_large_imms: + if self.config.ignore_large_imms and mnemonic not in arch.branch_instructions: row = re.sub(self.config.arch.re_large_imm, "", row) return row @@ -1020,8 +1333,8 @@ class DifferenceNormalizerAArch64(DifferenceNormalizer): if mnemonic != "bl": return row - row, _ = split_off_branch(row) - return row + row, _ = split_off_address(row) + return row + "" def _normalize_adrp_differences(self, mnemonic: str, row: str) -> str: """Identifies ADRP + LDR/ADD pairs that are used to access the GOT and @@ -1036,7 +1349,8 @@ class DifferenceNormalizerAArch64(DifferenceNormalizer): row_parts = row.split("\t", 1) if mnemonic == "adrp": self._adrp_pair_registers.add(row_parts[1].strip().split(",")[0]) - row, _ = split_off_branch(row) + row, _ = split_off_address(row) + return row + "" elif mnemonic == "ldr": for reg in self._adrp_pair_registers: # ldr xxx, [reg] @@ -1301,50 +1615,75 @@ class Line: diff_row: str original: str normalized_original: str - line_num: str - branch_target: Optional[str] - source_lines: List[str] - comment: Optional[str] + scorable_line: str + line_num: Optional[int] = None + branch_target: Optional[int] = None + source_filename: Optional[str] = None + source_line_num: Optional[int] = None + source_lines: List[str] = field(default_factory=list) + comment: Optional[str] = None -def process(lines: List[str], config: Config) -> List[Line]: +def process(dump: str, config: Config) -> List[Line]: arch = config.arch normalizer = arch.difference_normalizer(config) skip_next = False source_lines = [] - if not config.diff_obj: - lines = lines[7:] - if lines and not lines[-1]: - lines.pop() + source_filename = None + source_line_num = None + i = 0 + num_instr = 0 + data_refs: Dict[int, Dict[str, List[int]]] = defaultdict(lambda: defaultdict(list)) output: List[Line] = [] stop_after_delay_slot = False - for row in lines: + lines = dump.split("\n") + while i < len(lines): + row = lines[i] + i += 1 + if config.diff_obj and (">:" in row or not row): continue + if row.startswith("DATAREF"): + parts = row.split(" ", 3) + text_offset = int(parts[1]) + from_offset = int(parts[2]) + from_section = parts[3] + data_refs[text_offset][from_section].append(from_offset) + continue + + if config.diff_obj and num_instr >= config.max_function_size_lines: + output.append( + Line( + mnemonic="...", + diff_row="...", + original="...", + normalized_original="...", + scorable_line="...", + ) + ) + break + + # This regex is conservative, and assumes the file path does not contain "weird" + # characters like colons, tabs, or angle brackets. + if ( + config.show_line_numbers + and row + and re.match( + r"^[^ \t<>:][^\t<>:]*:[0-9]+( \(discriminator [0-9]+\))?$", row + ) + ): + source_filename, _, tail = row.rpartition(":") + source_line_num = int(tail.partition(" ")[0]) + if config.source: + source_lines.append(row) + continue + if config.source and not config.source_old_binutils and (row and row[0] != " "): source_lines.append(row) continue - if "R_AARCH64_" in row: - # TODO: handle relocation - continue - - if "R_MIPS_" in row: - # N.B. Don't transform the diff rows, they already ignore immediates - # if output[-1].diff_row != "": - # output[-1] = output[-1].replace(diff_row=process_mips_reloc(row, output[-1].row_with_imm, arch)) - new_original = process_mips_reloc(row, output[-1].original, arch) - output[-1] = replace(output[-1], original=new_original) - continue - - if "R_PPC_" in row: - new_original = process_ppc_reloc(row, output[-1].original) - output[-1] = replace(output[-1], original=new_original) - continue - - # match source lines here to avoid matching relocation lines if ( config.source and config.source_old_binutils @@ -1353,13 +1692,33 @@ def process(lines: List[str], config: Config) -> List[Line]: source_lines.append(row) continue + # `objdump --line-numbers` includes function markers, even without `--source` + if config.show_line_numbers and row and re.match(r"^[^ \t]+\(\):$", row): + continue + m_comment = re.search(arch.re_comment, row) comment = m_comment[0] if m_comment else None row = re.sub(arch.re_comment, "", row) row = row.rstrip() tabs = row.split("\t") row = "\t".join(tabs[2:]) - line_num = tabs[0].strip() + line_num = eval_line_num(tabs[0].strip()) + + if line_num in data_refs: + refs = data_refs[line_num] + ref_str = "; ".join( + section_name + "+" + ",".join(hex(off) for off in offs) + for section_name, offs in refs.items() + ) + output.append( + Line( + mnemonic="", + diff_row="", + original=ref_str, + normalized_original=ref_str, + scorable_line="", + ) + ) if "\t" in row: row_parts = row.split("\t", 1) @@ -1370,30 +1729,57 @@ def process(lines: List[str], config: Config) -> List[Line]: if mnemonic not in arch.instructions_with_address_immediates: row = re.sub(arch.re_int, lambda m: hexify_int(row, m, arch), row) + + # Let 'original' be 'row' with relocations applied, while we continue + # transforming 'row' into a coarser version that ignores registers and + # immediates. original = row + + while i < len(lines): + reloc_row = lines[i] + if "R_AARCH64_" in reloc_row: + # TODO: handle relocation + pass + elif "R_MIPS_" in reloc_row: + original = process_mips_reloc(reloc_row, original, arch) + elif "R_PPC_" in reloc_row: + original = process_ppc_reloc(reloc_row, original) + else: + break + i += 1 + normalized_original = normalizer.normalize(mnemonic, original) + + scorable_line = normalized_original + if not config.score_stack_differences: + scorable_line = re.sub(arch.re_sprel, "addr(sp)", scorable_line) + if mnemonic in arch.branch_instructions: + # Replace the final argument with "" + scorable_line = re.sub(r"[^, \t]+$", "", scorable_line) + if skip_next: skip_next = False row = "" mnemonic = "" + scorable_line = "" if mnemonic in arch.branch_likely_instructions: skip_next = True + row = re.sub(arch.re_reg, "", row) row = re.sub(arch.re_sprel, "addr(sp)", row) row_with_imm = row if mnemonic in arch.instructions_with_address_immediates: row = row.strip() - row, _ = split_off_branch(row) + row, _ = split_off_address(row) row += "" else: row = normalize_imms(row, arch) branch_target = None if mnemonic in arch.branch_instructions: - target = int(row_parts[1].strip().split(",")[-1], 16) + branch_target = int(row_parts[1].strip().split(",")[-1], 16) if mnemonic in arch.branch_likely_instructions: - target -= 4 - branch_target = hex(target)[2:] + branch_target -= 4 output.append( Line( @@ -1401,12 +1787,16 @@ def process(lines: List[str], config: Config) -> List[Line]: diff_row=row, original=original, normalized_original=normalized_original, + scorable_line=scorable_line, line_num=line_num, branch_target=branch_target, + source_filename=source_filename, + source_line_num=source_line_num, source_lines=source_lines, comment=comment, ) ) + num_instr += 1 source_lines = [] if config.stop_jrra and mnemonic == "jr" and row_parts[1].strip() == "ra": @@ -1425,7 +1815,13 @@ def normalize_stack(row: str, arch: ArchSettings) -> str: return re.sub(arch.re_sprel, "addr(sp)", row) -def split_off_branch(line: str) -> Tuple[str, str]: +def imm_matches_everything(row: str, arch: ArchSettings) -> bool: + # (this should probably be arch-specific) + return "(." in row + + +def split_off_address(line: str) -> Tuple[str, str]: + """Split e.g. 'beqz $r0,1f0' into 'beqz $r0,' and '1f0'.""" parts = line.split(",") if len(parts) < 2: parts = line.split(None, 1) @@ -1466,9 +1862,10 @@ def diff_sequences( rem1 = remap(seq1) rem2 = remap(seq2) - import Levenshtein # type: ignore + import Levenshtein - return Levenshtein.opcodes(rem1, rem2) # type: ignore + ret: List[Tuple[str, int, int, int, int]] = Levenshtein.opcodes(rem1, rem2) + return ret def diff_lines( @@ -1497,31 +1894,152 @@ def diff_lines( return ret +def score_diff_lines( + lines: List[Tuple[Optional[Line], Optional[Line]]], config: Config +) -> int: + # This logic is copied from `scorer.py` from the decomp permuter project + # https://github.com/simonlindholm/decomp-permuter/blob/main/src/scorer.py + score = 0 + deletions = [] + insertions = [] + + def lo_hi_match(old: str, new: str) -> bool: + # TODO: Make this arch-independent, like `imm_matches_everything()` + old_lo = old.find("%lo") + old_hi = old.find("%hi") + new_lo = new.find("%lo") + new_hi = new.find("%hi") + + if old_lo != -1 and new_lo != -1: + old_idx = old_lo + new_idx = new_lo + elif old_hi != -1 and new_hi != -1: + old_idx = old_hi + new_idx = new_hi + else: + return False + + if old[:old_idx] != new[:new_idx]: + return False + + old_inner = old[old_idx + 4 : -1] + new_inner = new[new_idx + 4 : -1] + return old_inner.startswith(".") or new_inner.startswith(".") + + def diff_sameline(old: str, new: str) -> None: + nonlocal score + if old == new: + return + + if lo_hi_match(old, new): + return + + ignore_last_field = False + if config.score_stack_differences: + oldsp = re.search(config.arch.re_sprel, old) + newsp = re.search(config.arch.re_sprel, new) + if oldsp and newsp: + oldrel = int(oldsp.group(1) or "0", 0) + newrel = int(newsp.group(1) or "0", 0) + score += abs(oldrel - newrel) * config.penalty_stackdiff + ignore_last_field = True + + # Probably regalloc difference, or signed vs unsigned + + # Compare each field in order + newfields, oldfields = new.split(","), old.split(",") + if ignore_last_field: + newfields = newfields[:-1] + oldfields = oldfields[:-1] + for nf, of in zip(newfields, oldfields): + if nf != of: + score += config.penalty_regalloc + # Penalize any extra fields + score += abs(len(newfields) - len(oldfields)) * config.penalty_regalloc + + def diff_insert(line: str) -> None: + # Reordering or totally different codegen. + # Defer this until later when we can tell. + insertions.append(line) + + def diff_delete(line: str) -> None: + deletions.append(line) + + # Find the end of the last long streak of matching mnemonics, if it looks + # like the objdump output was truncated. This is used to skip scoring + # misaligned lines at the end of the diff. + last_mismatch = -1 + max_index = None + lines_were_truncated = False + for index, (line1, line2) in enumerate(lines): + if (line1 and line1.original == "...") or (line2 and line2.original == "..."): + lines_were_truncated = True + if line1 and line2 and line1.mnemonic == line2.mnemonic: + if index - last_mismatch >= 50: + max_index = index + else: + last_mismatch = index + if not lines_were_truncated: + max_index = None + + for index, (line1, line2) in enumerate(lines): + if max_index is not None and index > max_index: + break + if line1 and line2 and line1.mnemonic == line2.mnemonic: + diff_sameline(line1.scorable_line, line2.scorable_line) + else: + if line1: + diff_delete(line1.scorable_line) + if line2: + diff_insert(line2.scorable_line) + + insertions_co = Counter(insertions) + deletions_co = Counter(deletions) + for item in insertions_co + deletions_co: + ins = insertions_co[item] + dels = deletions_co[item] + common = min(ins, dels) + score += ( + (ins - common) * config.penalty_insertion + + (dels - common) * config.penalty_deletion + + config.penalty_reordering * common + ) + + return score + + @dataclass(frozen=True) class OutputLine: base: Optional[Text] = field(compare=False) fmt2: Text = field(compare=False) key2: Optional[str] + boring: bool = field(compare=False) + is_data_ref: bool = field(compare=False) + line1: Optional[Line] = field(compare=False) + line2: Optional[Line] = field(compare=False) -def do_diff(basedump: str, mydump: str, config: Config) -> List[OutputLine]: +@dataclass(frozen=True) +class Diff: + lines: List[OutputLine] + score: int + + +def do_diff(lines1: List[Line], lines2: List[Line], config: Config) -> Diff: if config.source: - import cxxfilt # type: ignore + import cxxfilt arch = config.arch fmt = config.formatter output: List[OutputLine] = [] - lines1 = process(basedump.split("\n"), config) - lines2 = process(mydump.split("\n"), config) - sc1 = symbol_formatter("base-reg", 0) sc2 = symbol_formatter("my-reg", 0) sc3 = symbol_formatter("base-stack", 4) sc4 = symbol_formatter("my-stack", 4) sc5 = symbol_formatter("base-branch", 0) sc6 = symbol_formatter("my-branch", 0) - bts1: Set[str] = set() - bts2: Set[str] = set() + bts1: Set[int] = set() + bts2: Set[int] = set() if config.show_branches: for (lines, btset, sc) in [ @@ -1531,54 +2049,100 @@ def do_diff(basedump: str, mydump: str, config: Config) -> List[OutputLine]: for line in lines: bt = line.branch_target if bt is not None: - text = f"{bt}:" - btset.add(text) - sc(text) + btset.add(bt) + sc(str(bt)) - for (line1, line2) in diff_lines(lines1, lines2, config.algorithm): + diffed_lines = diff_lines(lines1, lines2, config.algorithm) + score = score_diff_lines(diffed_lines, config) + + line_num_base = -1 + line_num_offset = 0 + line_num_2to1 = {} + for (line1, line2) in diffed_lines: + if line1 is not None and line1.line_num is not None: + line_num_base = line1.line_num + line_num_offset = 0 + else: + line_num_offset += 1 + if line2 is not None and line2.line_num is not None: + line_num_2to1[line2.line_num] = (line_num_base, line_num_offset) + + for (line1, line2) in diffed_lines: line_color1 = line_color2 = sym_color = BasicFormat.NONE line_prefix = " " + is_data_ref = False out1 = Text() if not line1 else Text(pad_mnemonic(line1.original)) out2 = Text() if not line2 else Text(pad_mnemonic(line2.original)) if line1 and line2 and line1.diff_row == line2.diff_row: - if line1.normalized_original == line2.normalized_original: + if line1.diff_row == "": + if line1.normalized_original != line2.normalized_original: + line_prefix = "i" + sym_color = BasicFormat.DIFF_CHANGE + out1 = out1.reformat(sym_color) + out2 = out2.reformat(sym_color) + is_data_ref = True + elif ( + line1.normalized_original == line2.normalized_original + and line2.branch_target is None + ): + # Fast path: no coloring needed. We don't include branch instructions + # in this case because we need to check that their targets line up in + # the diff, and don't just happen to have the are the same address + # by accident. pass elif line1.diff_row == "": + # Don't draw attention to differing branch-likely delay slots: they + # typically mirror the branch destination - 1 so the real difference + # is elsewhere. Still, do mark them as different to avoid confusion. + # No need to consider branches because delay slots can't branch. out1 = out1.reformat(BasicFormat.DELAY_SLOT) out2 = out2.reformat(BasicFormat.DELAY_SLOT) else: mnemonic = line1.original.split()[0] - branch1 = branch2 = Text() + branchless1, address1 = out1.plain(), "" + branchless2, address2 = out2.plain(), "" if mnemonic in arch.instructions_with_address_immediates: - out1, branch1 = map(Text, split_off_branch(out1.plain())) - out2, branch2 = map(Text, split_off_branch(out2.plain())) - branchless1 = out1.plain() - branchless2 = out2.plain() + branchless1, address1 = split_off_address(branchless1) + branchless2, address2 = split_off_address(branchless2) + + out1 = Text(branchless1) + out2 = Text(branchless2) out1, out2 = format_fields( arch.re_imm, out1, out2, lambda _: BasicFormat.IMMEDIATE ) - same_relative_target = False - if line1.branch_target is not None and line2.branch_target is not None: - relative_target1 = eval_line_num( - line1.branch_target - ) - eval_line_num(line1.line_num) - relative_target2 = eval_line_num( - line2.branch_target - ) - eval_line_num(line2.line_num) - same_relative_target = relative_target1 == relative_target2 + if line2.branch_target is not None: + target = line2.branch_target + line2_target = line_num_2to1.get(line2.branch_target) + if line2_target is None: + # If the target is outside the disassembly, extrapolate. + # This only matters near the bottom. + assert line2.line_num is not None + line2_line = line_num_2to1[line2.line_num] + line2_target = (line2_line[0] + (target - line2.line_num), 0) - if not same_relative_target and branch1.plain() != branch2.plain(): - branch1 = branch1.reformat(BasicFormat.IMMEDIATE) - branch2 = branch2.reformat(BasicFormat.IMMEDIATE) + # Set the key for three-way diffing to a normalized version. + norm2, norm_branch2 = split_off_address(line2.normalized_original) + if norm_branch2 != "": + line2.normalized_original = norm2 + str(line2_target) + same_target = line2_target == (line1.branch_target, 0) + else: + # Do a naive comparison for non-branches (e.g. function calls). + same_target = address1 == address2 - out1 += branch1 - out2 += branch2 if normalize_imms(branchless1, arch) == normalize_imms( branchless2, arch ): - if not same_relative_target: - # only imms differences + if imm_matches_everything(branchless2, arch): + # ignore differences due to %lo(.rodata + ...) vs symbol + out1 = out1.reformat(BasicFormat.NONE) + out2 = out2.reformat(BasicFormat.NONE) + elif line2.branch_target is not None and same_target: + # same-target branch, don't color + pass + else: + # must have an imm difference (or else we would have hit the + # fast path) sym_color = BasicFormat.IMMEDIATE line_prefix = "i" else: @@ -1592,10 +2156,17 @@ def do_diff(basedump: str, mydump: str, config: Config) -> List[OutputLine]: sym_color = BasicFormat.STACK line_prefix = "s" else: - # regs differences and maybe imms as well + # reg differences and maybe imm as well out1, out2 = format_fields(arch.re_reg, out1, out2, sc1, sc2) line_color1 = line_color2 = sym_color = BasicFormat.REGISTER line_prefix = "r" + + if same_target: + address_imm_fmt = BasicFormat.NONE + else: + address_imm_fmt = BasicFormat.IMMEDIATE + out1 += Text(address1, address_imm_fmt) + out2 += Text(address2, address_imm_fmt) elif line1 and line2: line_prefix = "|" line_color1 = line_color2 = sym_color = BasicFormat.DIFF_CHANGE @@ -1619,25 +2190,25 @@ def do_diff(basedump: str, mydump: str, config: Config) -> List[OutputLine]: out: Text, line: Optional[Line], line_color: Format, - btset: Set[str], + btset: Set[int], sc: FormatFunction, ) -> Optional[Text]: if line is None: return None + if line.line_num is None: + return out in_arrow = Text(" ") out_arrow = Text() if config.show_branches: if line.line_num in btset: - in_arrow = Text("~>", sc(line.line_num)) + in_arrow = Text("~>", sc(str(line.line_num))) if line.branch_target is not None: - out_arrow = " " + Text("~>", sc(line.branch_target + ":")) - return ( - Text(line.line_num, line_color) + " " + in_arrow + " " + out + out_arrow - ) + out_arrow = " " + Text("~>", sc(str(line.branch_target))) + formatted_line_num = Text(hex(line.line_num)[2:] + ":", line_color) + return formatted_line_num + " " + in_arrow + " " + out + out_arrow part1 = format_part(out1, line1, line_color1, bts1, sc5) part2 = format_part(out2, line2, line_color2, bts2, sc6) - key2 = line2.original if line2 else None if line2: for source_line in line2.source_lines: @@ -1666,21 +2237,62 @@ def do_diff(basedump: str, mydump: str, config: Config) -> List[OutputLine]: ) except: pass + padding = " " * 7 if config.show_line_numbers else " " * 2 output.append( OutputLine( - None, - " " + Text(source_line, line_format), - source_line, + base=None, + fmt2=padding + Text(source_line, line_format), + key2=source_line, + boring=True, + is_data_ref=False, + line1=None, + line2=None, ) ) - fmt2 = Text(line_prefix, sym_color) + " " + (part2 or Text()) - output.append(OutputLine(part1, fmt2, key2)) + key2 = line2.normalized_original if line2 else None + boring = False + if line_prefix == " ": + boring = True + elif config.compress and config.compress.same_instr and line_prefix in "irs": + boring = True - return output + if config.show_line_numbers: + if line2 and line2.source_line_num is not None: + num_color = ( + BasicFormat.SOURCE_LINE_NUM + if sym_color == BasicFormat.NONE + else sym_color + ) + num2 = Text(f"{line2.source_line_num:5}", num_color) + else: + num2 = Text(" " * 5) + else: + num2 = Text() + + fmt2 = Text(line_prefix, sym_color) + num2 + " " + (part2 or Text()) + + output.append( + OutputLine( + base=part1, + fmt2=fmt2, + key2=key2, + boring=boring, + is_data_ref=is_data_ref, + line1=line1, + line2=line2, + ) + ) + + return Diff(lines=output, score=score) -def chunk_diff(diff: List[OutputLine]) -> List[Union[List[OutputLine], OutputLine]]: +def chunk_diff_lines( + diff: List[OutputLine], +) -> List[Union[List[OutputLine], OutputLine]]: + """Chunk a diff into an alternating list like A B A B ... A, where: + * A is a List[OutputLine] of insertions, + * B is a single non-insertion OutputLine, with .base != None.""" cur_right: List[OutputLine] = [] chunks: List[Union[List[OutputLine], OutputLine]] = [] for output_line in diff: @@ -1694,62 +2306,109 @@ def chunk_diff(diff: List[OutputLine]) -> List[Union[List[OutputLine], OutputLin return chunks -def format_diff( - old_diff: List[OutputLine], new_diff: List[OutputLine], config: Config -) -> Tuple[Optional[Tuple[str, ...]], List[Tuple[str, ...]]]: - fmt = config.formatter - old_chunks = chunk_diff(old_diff) - new_chunks = chunk_diff(new_diff) - output: List[Tuple[Text, OutputLine, OutputLine]] = [] - assert len(old_chunks) == len(new_chunks), "same target" - empty = OutputLine(Text(), Text(), None) - for old_chunk, new_chunk in zip(old_chunks, new_chunks): - if isinstance(old_chunk, list): - assert isinstance(new_chunk, list) - if not old_chunk and not new_chunk: - # Most of the time lines sync up without insertions/deletions, - # and there's no interdiffing to be done. - continue - differ = difflib.SequenceMatcher(a=old_chunk, b=new_chunk, autojunk=False) - for (tag, i1, i2, j1, j2) in differ.get_opcodes(): - if tag in ["equal", "replace"]: - for i, j in zip(range(i1, i2), range(j1, j2)): - output.append((Text(), old_chunk[i], new_chunk[j])) - if tag in ["insert", "replace"]: - for j in range(j1 + i2 - i1, j2): - output.append((Text(), empty, new_chunk[j])) - if tag in ["delete", "replace"]: - for i in range(i1 + j2 - j1, i2): - output.append((Text(), old_chunk[i], empty)) - else: - assert isinstance(new_chunk, OutputLine) - assert new_chunk.base - # old_chunk.base and new_chunk.base have the same text since - # both diffs are based on the same target, but they might - # differ in color. Use the new version. - output.append((new_chunk.base, old_chunk, new_chunk)) +def compress_matching( + li: List[Tuple[OutputLine, ...]], context: int +) -> List[Tuple[OutputLine, ...]]: + ret: List[Tuple[OutputLine, ...]] = [] + matching_streak: List[Tuple[OutputLine, ...]] = [] + context = max(context, 0) - # TODO: status line, with e.g. approximate permuter score? - header_line: Optional[Tuple[str, ...]] - diff_lines: List[Tuple[str, ...]] - if config.threeway: - header_line = ("TARGET", " CURRENT", " PREVIOUS") - diff_lines = [ - ( - fmt.apply(base), - fmt.apply(new.fmt2), - fmt.apply(old.fmt2) or "-" if old != new else "", + def flush_matching() -> None: + if len(matching_streak) <= 2 * context + 1: + ret.extend(matching_streak) + else: + ret.extend(matching_streak[:context]) + skipped = len(matching_streak) - 2 * context + filler = OutputLine( + base=Text(f"<{skipped} lines>", BasicFormat.SOURCE_OTHER), + fmt2=Text(), + key2=None, + boring=False, + is_data_ref=False, + line1=None, + line2=None, ) - for (base, old, new) in output + columns = len(matching_streak[0]) + ret.append(tuple([filler] * columns)) + if context > 0: + ret.extend(matching_streak[-context:]) + matching_streak.clear() + + for line in li: + if line[0].boring: + matching_streak.append(line) + else: + flush_matching() + ret.append(line) + + flush_matching() + return ret + + +def align_diffs( + old_diff: Diff, new_diff: Diff, config: Config +) -> Tuple[TableMetadata, List[Tuple[OutputLine, ...]]]: + meta: TableMetadata + diff_lines: List[Tuple[OutputLine, ...]] + padding = " " * 7 if config.show_line_numbers else " " * 2 + + if config.threeway: + meta = TableMetadata( + headers=( + Text("TARGET"), + Text(f"{padding}CURRENT ({new_diff.score})"), + Text(f"{padding}PREVIOUS ({old_diff.score})"), + ), + current_score=new_diff.score, + previous_score=old_diff.score, + ) + old_chunks = chunk_diff_lines(old_diff.lines) + new_chunks = chunk_diff_lines(new_diff.lines) + diff_lines = [] + empty = OutputLine(Text(), Text(), None, True, False, None, None) + assert len(old_chunks) == len(new_chunks), "same target" + for old_chunk, new_chunk in zip(old_chunks, new_chunks): + if isinstance(old_chunk, list): + assert isinstance(new_chunk, list) + if not old_chunk and not new_chunk: + # Most of the time lines sync up without insertions/deletions, + # and there's no interdiffing to be done. + continue + differ = difflib.SequenceMatcher( + a=old_chunk, b=new_chunk, autojunk=False + ) + for (tag, i1, i2, j1, j2) in differ.get_opcodes(): + if tag in ["equal", "replace"]: + for i, j in zip(range(i1, i2), range(j1, j2)): + diff_lines.append((empty, new_chunk[j], old_chunk[i])) + if tag in ["insert", "replace"]: + for j in range(j1 + i2 - i1, j2): + diff_lines.append((empty, new_chunk[j], empty)) + if tag in ["delete", "replace"]: + for i in range(i1 + j2 - j1, i2): + diff_lines.append((empty, empty, old_chunk[i])) + else: + assert isinstance(new_chunk, OutputLine) + # old_chunk.base and new_chunk.base have the same text since + # both diffs are based on the same target, but they might + # differ in color. Use the new version. + diff_lines.append((new_chunk, new_chunk, old_chunk)) + diff_lines = [ + (base, new, old if old != new else empty) for base, new, old in diff_lines ] else: - header_line = None - diff_lines = [ - (fmt.apply(base), fmt.apply(new.fmt2)) - for (base, old, new) in output - if base or new.key2 is not None - ] - return header_line, diff_lines + meta = TableMetadata( + headers=( + Text("TARGET"), + Text(f"{padding}CURRENT ({new_diff.score})"), + ), + current_score=new_diff.score, + previous_score=None, + ) + diff_lines = [(line, line) for line in new_diff.lines] + if config.compress: + diff_lines = compress_matching(diff_lines, config.compress.context) + return meta, diff_lines def debounced_fs_watch( @@ -1758,10 +2417,10 @@ def debounced_fs_watch( config: Config, project: ProjectSettings, ) -> None: - import watchdog.events # type: ignore - import watchdog.observers # type: ignore + import watchdog.events + import watchdog.observers - class WatchEventHandler(watchdog.events.FileSystemEventHandler): # type: ignore + class WatchEventHandler(watchdog.events.FileSystemEventHandler): def __init__( self, queue: "queue.Queue[float]", file_targets: List[str] ) -> None: @@ -1830,35 +2489,45 @@ def debounced_fs_watch( class Display: basedump: str mydump: str + last_refresh_key: object config: Config emsg: Optional[str] - last_diff_output: Optional[List[OutputLine]] - pending_update: Optional[Tuple[str, bool]] + last_diff_output: Optional[Diff] + pending_update: Optional[str] ready_queue: "queue.Queue[None]" watch_queue: "queue.Queue[Optional[float]]" less_proc: "Optional[subprocess.Popen[bytes]]" def __init__(self, basedump: str, mydump: str, config: Config) -> None: self.config = config - self.basedump = basedump + self.base_lines = process(basedump, config) self.mydump = mydump self.emsg = None + self.last_refresh_key = None self.last_diff_output = None - def run_diff(self) -> str: + def run_diff(self) -> Tuple[str, object]: if self.emsg is not None: - return self.emsg + return (self.emsg, self.emsg) - diff_output = do_diff(self.basedump, self.mydump, self.config) + my_lines = process(self.mydump, self.config) + diff_output = do_diff(self.base_lines, my_lines, self.config) last_diff_output = self.last_diff_output or diff_output if self.config.threeway != "base" or not self.last_diff_output: self.last_diff_output = diff_output - header, diff_lines = format_diff(last_diff_output, diff_output, self.config) - return self.config.formatter.table(header, diff_lines[self.config.skip_lines :]) - def run_less(self) -> "Tuple[subprocess.Popen[bytes], subprocess.Popen[bytes]]": - output = self.run_diff() + meta, diff_lines = align_diffs(last_diff_output, diff_output, self.config) + diff_lines = diff_lines[self.config.skip_lines :] + output = self.config.formatter.table(meta, diff_lines) + refresh_key = ( + [[col.key2 for col in x[1:]] for x in diff_lines], + diff_output.score, + ) + return (output, refresh_key) + def run_less( + self, output: str + ) -> "Tuple[subprocess.Popen[bytes], subprocess.Popen[bytes]]": # Pipe the output through 'tail' and only then to less, to ensure the # write call doesn't block. ('tail' has to buffer all its input before # it starts writing.) This also means we don't have to deal with pipe @@ -1875,7 +2544,8 @@ class Display: return (buffer_proc, less_proc) def run_sync(self) -> None: - proca, procb = self.run_less() + output, _ = self.run_diff() + proca, procb = self.run_less(output) procb.wait() proca.wait() @@ -1883,12 +2553,14 @@ class Display: self.watch_queue = watch_queue self.ready_queue = queue.Queue() self.pending_update = None - dthread = threading.Thread(target=self.display_thread) + output, refresh_key = self.run_diff() + self.last_refresh_key = refresh_key + dthread = threading.Thread(target=self.display_thread, args=(output,)) dthread.start() self.ready_queue.get() - def display_thread(self) -> None: - proca, procb = self.run_less() + def display_thread(self, initial_output: str) -> None: + proca, procb = self.run_less(initial_output) self.less_proc = procb self.ready_queue.put(None) while True: @@ -1900,14 +2572,9 @@ class Display: os.system("tput reset") if ret != 0 and self.pending_update is not None: # killed by program with the intent to refresh - msg, error = self.pending_update + output = self.pending_update self.pending_update = None - if not error: - self.mydump = msg - self.emsg = None - else: - self.emsg = msg - proca, procb = self.run_less() + proca, procb = self.run_less(output) self.less_proc = procb self.ready_queue.put(None) else: @@ -1925,7 +2592,17 @@ class Display: if not error and not self.emsg and text == self.mydump: self.progress("Unchanged. ") return - self.pending_update = (text, error) + if not error: + self.mydump = text + self.emsg = None + else: + self.emsg = text + output, refresh_key = self.run_diff() + if refresh_key == self.last_refresh_key: + self.progress("Unchanged. ") + return + self.last_refresh_key = refresh_key + self.pending_update = output if not self.less_proc: return self.less_proc.kill() @@ -1995,8 +2672,8 @@ def main() -> None: display = Display(basedump, mydump, config) - if args.no_pager or args.format == "html": - print(display.run_diff()) + if args.no_pager or args.format in ("html", "json"): + print(display.run_diff()[0]) elif not args.watch: display.run_sync() else: