#!/usr/bin/env python3

# SPDX-FileCopyrightText: 2024 zeldaret
# SPDX-License-Identifier: CC0-1.0

from __future__ import annotations

import argparse
from collections import Counter
import colorama
from dataclasses import dataclass
import io
import itertools
import multiprocessing
import multiprocessing.pool
from pathlib import Path
import re
import shlex
import sys
import time
from typing import BinaryIO, Iterator

from ido_block_numbers import (
    generate_make_log,
    find_compiler_command_line,
    run_cfe,
    SymbolTableEntry,
    UcodeOp,
)

import elftools.elf.elffile
import mapfile_parser.mapfile


def read_u32(f: BinaryIO, offset: int) -> int:
    """Read a big-endian unsigned 32-bit integer at byte `offset` of `f`."""
    f.seek(offset)
    return int.from_bytes(f.read(4), "big")


def read_u16(f: BinaryIO, offset: int) -> int:
    """Read a big-endian unsigned 16-bit integer at byte `offset` of `f`."""
    f.seek(offset)
    return int.from_bytes(f.read(2), "big")


def read_s16(f: BinaryIO, offset: int) -> int:
    """Read a big-endian signed 16-bit integer at byte `offset` of `f`."""
    f.seek(offset)
    return int.from_bytes(f.read(2), "big", signed=True)


class FixBssException(Exception):
    """An expected failure whose message is reported to the user as an error."""

    pass


@dataclass
class Reloc:
    """A relocation read from an object-file section by read_relocs()."""

    name: str  # name of the symbol the relocation refers to
    offset_32: int | None  # section offset of an R_MIPS_32 word, else None
    offset_hi16: int | None  # section offset of the paired R_MIPS_HI16 instruction, else None
    offset_lo16: int | None  # section offset of the R_MIPS_LO16 instruction, else None
    addend: int  # addend decoded from the section data


@dataclass
class Pointer:
    """A relocated value as it appears in both the baserom and the current build."""

    name: str  # symbol name from the relocation
    addend: int  # addend from the object file
    base_value: int  # value read from the baserom
    build_value: int  # value read from the current build


# Read relocations from an ELF file section
def read_relocs(object_path: Path, section_name: str) -> list[Reloc]:
    """Read the R_MIPS_32 and R_MIPS_HI16/LO16 relocations of an ELF section.

    Returns an empty list when the object has no `.rel<section_name>` section.
    R_MIPS_26 relocations are skipped.

    Raises:
        NotImplementedError: for any other relocation type.
    """
    with open(object_path, "rb") as f:
        elffile = elftools.elf.elffile.ELFFile(f)
        symtab = elffile.get_section_by_name(".symtab")
        data = elffile.get_section_by_name(section_name).data()

        reloc_section = elffile.get_section_by_name(f".rel{section_name}")
        if reloc_section is None:
            return []

        relocs = []
        # Offset of the most recent R_MIPS_HI16; each following R_MIPS_LO16 pairs with it
        offset_hi16 = 0
        for reloc in reloc_section.iter_relocations():
            reloc_offset = reloc.entry["r_offset"]
            reloc_type = reloc.entry["r_info_type"]
            reloc_name = symtab.get_symbol(reloc.entry["r_info_sym"]).name

            if reloc_type == 2:  # R_MIPS_32
                offset_32 = reloc_offset
                addend = int.from_bytes(
                    data[reloc_offset : reloc_offset + 4], "big", signed=True
                )
                relocs.append(Reloc(reloc_name, offset_32, None, None, addend))
            elif reloc_type == 4:  # R_MIPS_26
                pass
            elif reloc_type == 5:  # R_MIPS_HI16
                offset_hi16 = reloc_offset
            elif reloc_type == 6:  # R_MIPS_LO16
                offset_lo16 = reloc_offset
                # Both addend halves live in the low 16 bits of their instruction words
                addend_hi16 = int.from_bytes(
                    data[offset_hi16 + 2 : offset_hi16 + 4], "big", signed=False
                )
                addend_lo16 = int.from_bytes(
                    data[offset_lo16 + 2 : offset_lo16 + 4], "big", signed=True
                )
                addend = (addend_hi16 << 16) + addend_lo16
                relocs.append(
                    Reloc(reloc_name, None, offset_hi16, offset_lo16, addend)
                )
            else:
                raise NotImplementedError(f"Unsupported relocation type: {reloc_type}")

    return relocs


def get_file_pointers(
    file: mapfile_parser.mapfile.File,
    base: BinaryIO,
    build: BinaryIO,
) -> list[Pointer]:
    """Resolve every relocation of one mapfile section in both ROM images.

    Relocations are read from the section's object file; the relocated values
    are read from `base` and `build` at `file.vrom` plus the relocation offset.

    Raises:
        FixBssException: if the upper halfwords of a HI16/LO16 instruction pair
            differ between the two ROMs, or if a relocation carries no offsets.
    """
    pointers = []
    # TODO: open each ELF file only once instead of once per section?
    for reloc in read_relocs(file.filepath, file.sectionType):
        if reloc.offset_32 is not None:
            base_value = read_u32(base, file.vrom + reloc.offset_32)
            build_value = read_u32(build, file.vrom + reloc.offset_32)
        elif reloc.offset_hi16 is not None and reloc.offset_lo16 is not None:
            # Only the low halfwords may differ; anything else means the code itself differs
            if (
                read_u16(base, file.vrom + reloc.offset_hi16)
                != read_u16(build, file.vrom + reloc.offset_hi16)
            ) or (
                read_u16(base, file.vrom + reloc.offset_lo16)
                != read_u16(build, file.vrom + reloc.offset_lo16)
            ):
                raise FixBssException(
                    f"Reference to {reloc.name} in {file.filepath} is in a shifted or non-matching portion of the ROM.\n"
                    "Please ensure that the only differences between the baserom and the current build are due to BSS ordering."
                )
            base_value = (
                read_u16(base, file.vrom + reloc.offset_hi16 + 2) << 16
            ) + read_s16(base, file.vrom + reloc.offset_lo16 + 2)
            build_value = (
                read_u16(build, file.vrom + reloc.offset_hi16 + 2) << 16
            ) + read_s16(build, file.vrom + reloc.offset_lo16 + 2)
        else:
            # Raise instead of `assert`: asserts are stripped under `python -O`, which
            # would append a Pointer built from the previous iteration's values.
            raise FixBssException(
                f"Invalid relocation for {reloc.name} in {file.filepath}"
            )
        pointers.append(Pointer(reloc.name, reloc.addend, base_value, build_value))
    return pointers


# ROM images opened once per pool worker process by get_file_pointers_worker_init
base = None
build = None


def get_file_pointers_worker_init(version: str):
    """Pool initializer: open the baserom and the built ROM for `version`."""
    global base
    global build
    base = open(f"baseroms/{version}/baserom-decompressed.z64", "rb")
    build = open(f"build/{version}/oot-{version}.z64", "rb")


def get_file_pointers_worker(file: mapfile_parser.mapfile.File) -> list[Pointer]:
    """Pool task: run get_file_pointers() against the ROMs opened by the initializer."""
    assert base is not None
    assert build is not None
    return get_file_pointers(file, base, build)


# Compare pointers between the baserom and the current build, returning a dictionary from
# C files to a list of pointers into their BSS sections
def compare_pointers(version: str) -> dict[Path, list[Pointer]]:
    """Map each C file to the pointers (from any non-BSS section) into its .bss.

    Raises:
        FixBssException: if the map file for `version` does not exist.

    Exits the process with status 1 after printing the error if comparing any
    section raises FixBssException.
    """
    mapfile_path = Path(f"build/{version}/oot-{version}.map")
    if not mapfile_path.exists():
        raise FixBssException(f"Could not open {mapfile_path}")

    mapfile = mapfile_parser.mapfile.MapFile()
    mapfile.readMapFile(mapfile_path)

    # Segments built from source code (filtering out assets)
    source_code_segments = [
        mapfile_segment
        for mapfile_segment in mapfile
        if mapfile_segment.name.startswith(("..boot", "..code", "..buffers", "..ovl_"))
    ]

    # Find all pointers with different values
    if not sys.stdout.isatty():
        print("Comparing pointers between baserom and build ...")
    pointers = []
    file_results = []
    with multiprocessing.Pool(
        initializer=get_file_pointers_worker_init,
        initargs=(version,),
    ) as p:
        for mapfile_segment in source_code_segments:
            for file in mapfile_segment:
                if not str(file.filepath).endswith(".o"):
                    continue
                if file.sectionType == ".bss":
                    continue
                file_result = p.apply_async(get_file_pointers_worker, (file,))
                file_results.append(file_result)

        # Report progress and wait until all files are done
        num_files = len(file_results)
        while True:
            time.sleep(0.010)
            num_files_done = sum(file_result.ready() for file_result in file_results)
            if sys.stdout.isatty():
                print(
                    f"Comparing pointers between baserom and build ... {num_files_done:>{len(f'{num_files}')}}/{num_files}",
                    end="\r",
                )
            if num_files_done == num_files:
                break
        if sys.stdout.isatty():
            print("")

        # Collect results and check for errors
        for file_result in file_results:
            try:
                pointers.extend(file_result.get())
            except FixBssException as e:
                print(f"{colorama.Fore.RED}Error: {str(e)}{colorama.Fore.RESET}")
                sys.exit(1)

    # Remove duplicates and sort by baserom address
    pointers = list({p.base_value: p for p in pointers}.values())
    pointers.sort(key=lambda p: p.base_value)

    # Go through sections and collect differences
    pointers_by_file = {}
    for mapfile_segment in source_code_segments:
        for file in mapfile_segment:
            if file.sectionType != ".bss":
                continue

            # Pointers whose build value lands inside this .bss section
            pointers_in_section = [
                p
                for p in pointers
                if file.vram <= p.build_value < file.vram + file.size
            ]
            if not pointers_in_section:
                continue

            c_file = file.filepath.relative_to(f"build/{version}").with_suffix(".c")
            pointers_by_file[c_file] = pointers_in_section

    return pointers_by_file


@dataclass
class Pragma:
    """An increment_block_number pragma and the block numbers it generates."""

    line_number: int  # 1-based line of the pragma in the C file
    block_number: int  # symbol table index of the first symbol generated for this line
    amount: int  # number of symbols generated for this line


# A BSS variable in the source code
@dataclass
class BssVariable:
    block_number: int  # symbol table index of the variable
    name: str  # "<function>::<name>" for function-level statics
    size: int
    align: int


# A BSS variable with its offset in the compiled .bss section
@dataclass
class BssSymbol:
    name: str
    offset: int
    size: int
    align: int


# Matches the symbol names generated for increment_block_number pragmas; group 1 is the line number
INCREMENT_BLOCK_NUMBER_RE = re.compile(r"increment_block_number_(\d+)_(\d+)")

# Find increment_block_number pragmas by parsing the symbol names generated by preprocess.py.
# This is pretty ugly but it seems more reliable than trying to determine the line numbers of
# BSS variables in the C file.
def find_pragmas(symbol_table: list[SymbolTableEntry]) -> list[Pragma]:
    """Build one Pragma per source line that generated increment_block_number symbols.

    A pragma's block number is the symbol table index of the first matching
    symbol for its line, and its amount is how many matching symbols that line
    produced. The result is ordered by line number.
    """
    first_index_by_line: dict[int, int] = {}
    count_by_line: Counter[int] = Counter()
    for index, entry in enumerate(symbol_table):
        found = INCREMENT_BLOCK_NUMBER_RE.match(entry.name)
        if found is None:
            continue
        source_line = int(found.group(1))
        first_index_by_line.setdefault(source_line, index)
        count_by_line[source_line] += 1

    return [
        Pragma(source_line, first_index, count_by_line[source_line])
        for source_line, first_index in sorted(first_index_by_line.items())
    ]


# Find all BSS variables from IDO's symbol table and U-Code output.
def find_bss_variables(
    symbol_table: list[SymbolTableEntry], ucode: list[UcodeOp]
) -> list[BssVariable]:
    """List the uninitialized variables declared in the U-Code, sorted by block number.

    A `gsym`/`lsym`/`fsym` symbol is treated as BSS unless an `init` op targets
    the same block number. Function-level statics (`fsym`) are named
    `<function>::<name>`, using the function from the most recent `ent` op.
    """
    initialized = {op.i1 for op in ucode if op.opcode_name == "init"}

    found: list[BssVariable] = []
    current_function = None
    for op in ucode:
        opcode = op.opcode_name
        if opcode == "ent":
            current_function = symbol_table[op.i1].name
            continue
        # gsym: file-level global symbol
        # lsym: file-level static symbol
        # fsym: function-level static symbol
        if opcode not in ("gsym", "lsym", "fsym"):
            continue
        block = op.i1
        if block in initialized:
            continue  # not BSS

        symbol_name = symbol_table[block].name
        if opcode == "fsym":
            symbol_name = f"{current_function}::{symbol_name}"
        var_size = op.args[0]
        lexlev_align = 1 << op.lexlev
        # TODO: IDO seems to automatically align anything with size 8 or more to
        # an 8-byte boundary in BSS. Is this correct?
        var_align = 8 if var_size >= 8 else lexlev_align
        found.append(BssVariable(block, symbol_name, var_size, var_align))

    return sorted(found, key=lambda variable: variable.block_number)


# Predict offsets of BSS variables in the build.
def predict_bss_ordering(variables: list[BssVariable]) -> list[BssSymbol]:
    """Lay out BSS variables in the order this tool models for IDO.

    Variables are stably sorted by block number mod 256, and each one is placed
    at the next offset rounded up to its alignment.
    """
    bss_symbols = []
    offset = 0
    # Sort by block number mod 256 (for ties, the original order is preserved)
    for var in sorted(variables, key=lambda var: var.block_number % 256):
        size = var.size
        align = var.align
        # Bitmask round-up; assumes align is a power of two
        offset = (offset + align - 1) & ~(align - 1)
        bss_symbols.append(BssSymbol(var.name, offset, size, align))
        offset += size
    return bss_symbols


# Match up BSS variables between the baserom and the build using the pointers from relocations.
# Note that we may not be able to match all variables if a variable is not referenced by any pointer.
def determine_base_bss_ordering(
    build_bss_symbols: list[BssSymbol], pointers: list[Pointer]
) -> list[BssSymbol]:
    """Infer baserom offsets of the build's BSS symbols from pointer values.

    Each pointer is located inside a build symbol. The symbol's baserom offset is
    then the pointer's baserom offset minus the same distance into the symbol.
    Returns the matched symbols sorted by baserom offset.

    Raises:
        FixBssException: if `pointers` is empty, if a pointer falls outside every
            build symbol, or if one symbol is inferred at two different offsets.
    """
    if not pointers:
        raise FixBssException("No pointers to determine the baserom BSS ordering from")

    # Assume that the lowest address is the start of the BSS section
    base_section_start = min(p.base_value for p in pointers)
    build_section_start = min(p.build_value for p in pointers)

    found_symbols: dict[str, BssSymbol] = {}
    for p in pointers:
        base_offset = p.base_value - base_section_start
        build_offset = p.build_value - build_section_start

        new_symbol = None
        new_offset = 0
        for symbol in build_bss_symbols:
            if symbol.offset <= build_offset < symbol.offset + symbol.size:
                new_symbol = symbol
                # Keep the same distance into the symbol in both ROMs
                new_offset = base_offset - (build_offset - symbol.offset)
                break

        if new_symbol is None:
            if p.addend > 0:
                addend_str = f"+0x{p.addend:X}"
            elif p.addend < 0:
                addend_str = f"-0x{-p.addend:X}"
            else:
                addend_str = ""
            raise FixBssException(
                f"Could not find BSS symbol for pointer {p.name}{addend_str} "
                f"(base address 0x{p.base_value:08X}, build address 0x{p.build_value:08X})"
            )

        if new_symbol.name in found_symbols:
            # Sanity check that offsets agree
            existing_offset = found_symbols[new_symbol.name].offset
            if new_offset != existing_offset:
                raise FixBssException(
                    f"BSS symbol {new_symbol.name} found at conflicting offsets in this baserom "
                    f"(0x{existing_offset:04X} and 0x{new_offset:04X}). Is the build up-to-date?"
                )
        else:
            found_symbols[new_symbol.name] = BssSymbol(
                new_symbol.name, new_offset, new_symbol.size, new_symbol.align
            )

    return sorted(found_symbols.values(), key=lambda symbol: symbol.offset)


# Generate a sequence of integers in the range [0, 256) with a 2-adic valuation of exactly `nu`.
# The 2-adic valuation of an integer n is the largest k such that 2^k divides n
# (see https://en.wikipedia.org/wiki/P-adic_valuation), and for convenience we define
# the 2-adic valuation of 0 to be 8. Here's what the sequences look like for nu = 0..8:
#   8: 0
#   7: 128
#   6: 64, 192
#   5: 32, 96, 160, 224
#   4: 16, 48, 80, 112, ...
#   3: 8, 24, 40, 56, ...
#   2: 4, 12, 20, 28, ...
#   1: 2, 6, 10, 14, ...
#   0: 1, 3, 5, 7, ...
def gen_seq(nu: int) -> Iterator[int]:
    if nu == 8:
        yield 0
    else:
        # Odd multiples of 2^nu below 256
        for i in range(1 << (7 - nu)):
            yield (2 * i + 1) * (1 << nu)


# Yields all n-tuples of integers in the range [0, 256) with minimum 2-adic valuation
# of exactly `min_nu`.
def gen_candidates_impl(n: int, min_nu: int) -> Iterator[tuple[int, ...]]:
    if n == 1:
        for value in gen_seq(min_nu):
            yield (value,)
    else:
        # (a, *b) has min 2-adic valuation = min_nu if and only if either:
        #   a has 2-adic valuation > min_nu and b has min 2-adic valuation == min_nu
        #   a has 2-adic valuation == min_nu and b has min 2-adic valuation >= min_nu
        for min_nu_a in reversed(range(min_nu + 1, 9)):
            for a in gen_seq(min_nu_a):
                for b in gen_candidates_impl(n - 1, min_nu):
                    yield (a, *b)
        for a in gen_seq(min_nu):
            for min_nu_b in reversed(range(min_nu, 9)):
                for b in gen_candidates_impl(n - 1, min_nu_b):
                    yield (a, *b)


# Yields all n-tuples of integers in the range [0, 256), ordered by descending minimum
# 2-adic valuation of the elements in the tuple. For example, for n = 2 the sequence is:
#   (0, 0), (0, 128), (128, 0), (128, 128), (0, 64), (0, 192), (128, 64), (128, 192), ...
def gen_candidates(n: int) -> Iterator[tuple[int, ...]]:
    """Yield every n-tuple over [0, 256), by descending minimum 2-adic valuation."""
    for nu in reversed(range(9)):
        yield from gen_candidates_impl(n, nu)


# Determine a new set of increment_block_number pragmas that will fix the BSS ordering.
def solve_bss_ordering(
    pragmas: list[Pragma],
    bss_variables: list[BssVariable],
    base_bss_symbols: list[BssSymbol],
) -> list[Pragma]:
    """Find pragma amounts whose predicted BSS layout matches the baserom.

    Only symbols present in `base_bss_symbols` are checked; the first candidate
    that places all of them at their baserom offsets is returned.

    Raises:
        FixBssException: if no combination of amounts matches.
    """
    base_symbols_by_name = {symbol.name: symbol for symbol in base_bss_symbols}

    # Our "algorithm" just tries all possible combinations of increment_block_number amounts,
    # which can get very slow with more than a few pragmas. But, we order the candidates in a
    # binary-search-esque way to try to find a solution faster.
    for new_amounts in gen_candidates(len(pragmas)):
        # Shift each variable's block number by the amount changes of all pragmas before it
        new_bss_variables = []
        for var in bss_variables:
            new_block_number = var.block_number
            for pragma, new_amount in zip(pragmas, new_amounts):
                if var.block_number >= pragma.block_number:
                    new_block_number += new_amount - pragma.amount
            new_bss_variables.append(
                BssVariable(new_block_number, var.name, var.size, var.align)
            )

        # Predict new BSS and check if new ordering matches
        new_bss_symbols = predict_bss_ordering(new_bss_variables)
        if all(
            symbol.offset == base_symbols_by_name[symbol.name].offset
            for symbol in new_bss_symbols
            if symbol.name in base_symbols_by_name
        ):
            return [
                Pragma(pragma.line_number, pragma.block_number, new_amount)
                for pragma, new_amount in zip(pragmas, new_amounts)
            ]

    raise FixBssException("Could not find any solutions")


def update_source_file(version_to_update: str, file: Path, new_pragmas: list[Pragma]):
    """Rewrite the `version_to_update` amount of each pragma in `file`, keeping other versions.

    Each pragma line must look like `#pragma increment_block_number "ver:amount ..."`.
    A version not yet listed is appended at the end of the argument.

    Raises:
        FixBssException: if a pragma line is missing, unquoted, or has a
            malformed `version:amount` entry.
    """
    prefix = "#pragma increment_block_number "

    with open(file, "r", encoding="utf-8") as f:
        lines = f.readlines()

    for pragma in new_pragmas:
        line = lines[pragma.line_number - 1]
        if not line.startswith(prefix):
            raise FixBssException(
                f"Expected #pragma increment_block_number on line {pragma.line_number}"
            )

        # Grab pragma argument and remove quotes; validate instead of slicing blindly,
        # which would silently mangle an unquoted or oddly spaced argument
        quoted_arg = line[len(prefix) :].strip()
        if len(quoted_arg) < 2 or not (
            quoted_arg.startswith('"') and quoted_arg.endswith('"')
        ):
            raise FixBssException(
                f"Expected a quoted argument for #pragma increment_block_number on line {pragma.line_number}"
            )
        arg = quoted_arg[1:-1]

        amounts_by_version: dict[str, int] = {}
        for part in arg.split():
            version, sep, amount_str = part.partition(":")
            try:
                amount = int(amount_str)
            except ValueError:
                amount = None
            if not sep or amount is None:
                raise FixBssException(
                    f"Malformed increment_block_number argument {part!r} on line {pragma.line_number}"
                )
            amounts_by_version[version] = amount

        amounts_by_version[version_to_update] = pragma.amount

        new_arg = " ".join(
            f"{version}:{amount}" for version, amount in amounts_by_version.items()
        )
        new_line = f'#pragma increment_block_number "{new_arg}"\n'

        lines[pragma.line_number - 1] = new_line

    with open(file, "w", encoding="utf-8") as f:
        f.writelines(lines)


def process_file(
    file: Path,
    pointers: list[Pointer],
    make_log: list[str],
    dry_run: bool,
    version: str,
):
    """Compute new increment_block_number amounts for one C file and apply them unless `dry_run`.

    Raises:
        FixBssException: if the compiler command line is not found, there are no
            pointers into the file's BSS, the file has no pragmas or more than
            `max_pragmas` of them, or no solution exists.
    """
    print(f"{colorama.Fore.CYAN}Processing {file} ...{colorama.Fore.RESET}")

    command_line = find_compiler_command_line(make_log, file)
    if command_line is None:
        raise FixBssException(f"Could not determine compiler command line for {file}")

    print(f"Compiler command: {shlex.join(command_line)}")
    symbol_table, ucode = run_cfe(command_line, keep_files=False)

    bss_variables = find_bss_variables(symbol_table, ucode)
    print("BSS variables:")
    for var in bss_variables:
        i = var.block_number
        print(
            f" {i:>6} [{i%256:>3}]: size=0x{var.size:04X} align=0x{var.align:X} {var.name}"
        )

    build_bss_symbols = predict_bss_ordering(bss_variables)
    print("Current build BSS ordering:")
    for symbol in build_bss_symbols:
        print(
            f" offset=0x{symbol.offset:04X} size=0x{symbol.size:04X} align=0x{symbol.align:X} {symbol.name}"
        )

    if not pointers:
        raise FixBssException(f"No pointers to BSS found in ROM for {file}")

    base_bss_symbols = determine_base_bss_ordering(build_bss_symbols, pointers)
    print("Baserom BSS ordering:")
    for symbol in base_bss_symbols:
        print(
            f" offset=0x{symbol.offset:04X} size=0x{symbol.size:04X} align=0x{symbol.align:X} {symbol.name}"
        )

    pragmas = find_pragmas(symbol_table)
    # solve_bss_ordering tries every combination of amounts, so keep the search small
    max_pragmas = 3
    if not pragmas:
        raise FixBssException(f"No increment_block_number pragmas found in {file}")
    elif len(pragmas) > max_pragmas:
        raise FixBssException(
            f"Too many increment_block_number pragmas found in {file} (found {len(pragmas)}, max {max_pragmas})"
        )

    print("Solving BSS ordering ...")
    new_pragmas = solve_bss_ordering(pragmas, bss_variables, base_bss_symbols)
    print("New increment_block_number amounts:")
    for pragma in new_pragmas:
        print(f" line {pragma.line_number}: {pragma.amount}")

    if not dry_run:
        update_source_file(version, file, new_pragmas)
        print(f"{colorama.Fore.GREEN}Updated {file}{colorama.Fore.RESET}")


def process_file_worker(*x):
    """Pool task: run process_file(*x) and print its buffered output in one piece.

    Any exception is written into the buffered output and then re-raised, so the
    task's result reports failure.
    """
    # Collect output in a buffer to avoid interleaving output when processing multiple files
    old_stdout = sys.stdout
    fake_stdout = io.StringIO()
    try:
        sys.stdout = fake_stdout
        process_file(*x)
    except Exception as e:
        print(f"{colorama.Fore.RED}Error: {str(e)}{colorama.Fore.RESET}")
        raise
    finally:
        sys.stdout = old_stdout
        print()
        print(fake_stdout.getvalue(), end="")


def main():
    """Command-line entry point; exits with status 1 if any file fails."""
    parser = argparse.ArgumentParser(
        description="Automatically fix BSS ordering by editing increment_block_number pragmas. "
        "Assumes that the build is up-to-date and that only differences between the baserom and "
        "the current build are due to BSS ordering."
    )
    parser.add_argument(
        "--oot-version",
        "-v",
        type=str,
        required=True,
        help="OOT version",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print changes instead of editing source files",
    )
    parser.add_argument(
        "files",
        metavar="FILE",
        nargs="*",
        type=Path,
        help="Fix BSS ordering for a particular C file (default: all files with BSS differences)",
    )
    args = parser.parse_args()

    version = args.oot_version

    pointers_by_file = compare_pointers(version)

    files_with_reordering = []
    for file, pointers in pointers_by_file.items():
        # Try to detect if the section is shifted by comparing the lowest
        # address among any pointer into the section between base and build
        base_min_address = min(p.base_value for p in pointers)
        build_min_address = min(p.build_value for p in pointers)
        if not all(
            p.build_value - build_min_address == p.base_value - base_min_address
            for p in pointers
        ):
            files_with_reordering.append(file)

    if files_with_reordering:
        print("Files with BSS reordering:")
        for file in files_with_reordering:
            print(f" {file}")
    else:
        print("No BSS reordering found.")

    if args.files:
        files_to_fix = args.files
    else:
        files_to_fix = files_with_reordering
    if not files_to_fix:
        return

    print("Running make to find compiler command line ...")
    make_log = generate_make_log(version)

    with multiprocessing.Pool() as p:
        file_results = []
        for file in files_to_fix:
            file_result = p.apply_async(
                process_file_worker,
                (
                    file,
                    pointers_by_file.get(file, []),
                    make_log,
                    args.dry_run,
                    version,
                ),
            )
            file_results.append(file_result)

        # Wait until all files are done
        while not all(file_result.ready() for file_result in file_results):
            time.sleep(0.010)

        # Collect results and check for errors
        num_successes = sum(file_result.successful() for file_result in file_results)
        if num_successes == len(file_results):
            print()
            print(f"Updated {num_successes}/{len(file_results)} files.")
        else:
            print()
            print(
                f"{colorama.Fore.RED}Updated {num_successes}/{len(file_results)} files.{colorama.Fore.RESET}"
            )
            sys.exit(1)


if __name__ == "__main__":
    main()