#!/usr/bin/env python3

# SPDX-FileCopyrightText: 2024 zeldaret
# SPDX-License-Identifier: CC0-1.0


from __future__ import annotations

import argparse
from collections import Counter
import colorama
from dataclasses import dataclass
import io
import multiprocessing
import multiprocessing.pool
from pathlib import Path
import re
import shlex
import sys
import time
import traceback
from typing import BinaryIO, Iterator, Optional, Tuple

from ido_block_numbers import (
    generate_make_log,
    find_compiler_command_line,
    run_cfe,
    SymbolTableEntry,
    UcodeOp,
)

import elftools.elf.elffile
import mapfile_parser.mapfile


# Whether the real stdout is a terminal (controls colored output in output()).
# Set on program start since we replace sys.stdout in worker processes
# (process_file_worker swaps in an io.StringIO, which is never a TTY).
stdout_isatty: bool = sys.stdout.isatty()


def output(message: str = "", color: Optional[str] = None, end: str = "\n"):
    """Print ``message`` followed by ``end`` to the current ``sys.stdout``.

    When ``color`` (a terminal color prefix such as a ``colorama.Fore`` value) is
    given and stdout was a terminal at program start, the message is wrapped in
    that color followed by a color reset; otherwise it is printed unchanged.
    """
    colorize = bool(color) and stdout_isatty
    text = f"{color}{message}{colorama.Fore.RESET}" if colorize else message
    print(text, end=end)


def read_u32(f: BinaryIO, offset: int) -> int:
    """Return the big-endian unsigned 32-bit value stored at ``offset`` in ``f``.

    Moves the file position of ``f``. A short read near end of file is not
    detected: whatever bytes were read (possibly none) are decoded as-is.
    """
    f.seek(offset)
    raw = f.read(4)
    return int.from_bytes(raw, byteorder="big", signed=False)


def read_u16(f: BinaryIO, offset: int) -> int:
    """Return the big-endian unsigned 16-bit value stored at ``offset`` in ``f``.

    Moves the file position of ``f``; a short read is decoded as-is.
    """
    f.seek(offset)
    raw = f.read(2)
    return int.from_bytes(raw, byteorder="big", signed=False)


def read_s16(f: BinaryIO, offset: int) -> int:
    """Return the big-endian two's-complement 16-bit value stored at ``offset`` in ``f``.

    Moves the file position of ``f``; a short read is decoded as-is.
    """
    f.seek(offset)
    raw = f.read(2)
    return int.from_bytes(raw, byteorder="big", signed=True)


class FixBssException(Exception):
    """Error whose message is meant to be shown to the user as-is."""


@dataclass
class Reloc:
    """A relocation read by read_relocs: an R_MIPS_32 word or an R_MIPS_HI16/LO16 pair."""

    # Name of the symbol the relocation refers to (names starting with "." are sections)
    name: str
    # Offset within the section of the word patched by an R_MIPS_32 relocation, else None
    offset_32: int | None
    # Offset within the section of the R_MIPS_HI16 instruction of a pair, else None
    offset_hi16: int | None
    # Offset within the section of the R_MIPS_LO16 instruction of a pair, else None
    offset_lo16: int | None
    # Addend stored in the section data (the combined %hi/%lo value for HI16/LO16 pairs)
    addend: int


@dataclass
class Pointer:
    """A relocated pointer value as it appears in both the baserom and the current build."""

    # Name of the symbol or section the relocation refers to
    name: str
    # Relocation addend for section-relative relocations; 0 for symbol relocations
    # (whose addend has already been subtracted from the values below)
    addend: int
    # Pointer value read from the baserom
    base_value: int
    # Pointer value read from the current build
    build_value: int


@dataclass
class BssSection:
    """The .bss section of one object file plus the pointers that point into it."""

    # VRAM address of the section's start in the current build
    start_address: int
    # Pointers whose build value falls inside this section
    pointers: list[Pointer]


# Read relocations from an ELF file section
def read_relocs(object_path: Path, section_name: str) -> list[Reloc]:
    """Read the relocations applied to section ``section_name`` of the ELF object ``object_path``.

    Only the ``.rel<section_name>`` relocation section is consulted. Supported types:

    * R_MIPS_32: one Reloc with ``offset_32`` set; the addend is the signed word in place.
    * R_MIPS_HI16 / R_MIPS_LO16: one Reloc per LO16, paired with the most recent preceding
      HI16 (several LO16s may share one HI16); the addend is the combined %hi/%lo value.
    * R_MIPS_26: ignored.

    Returns an empty list when the object has no relocation section for ``section_name``.

    Raises:
        NotImplementedError: for any other relocation type.
        FixBssException: if the relocated section is missing, or an R_MIPS_LO16 has no
            preceding R_MIPS_HI16 (previously this silently paired it with offset 0).
    """
    with open(object_path, "rb") as f:
        elffile = elftools.elf.elffile.ELFFile(f)

        # Check for relocations first so that a section without relocations never needs
        # its data read (and a missing section doesn't crash with AttributeError).
        reloc_section = elffile.get_section_by_name(f".rel{section_name}")
        if reloc_section is None:
            return []

        section = elffile.get_section_by_name(section_name)
        if section is None:
            raise FixBssException(
                f"{object_path} has a .rel{section_name} section but no {section_name} section"
            )
        data = section.data()
        symtab = elffile.get_section_by_name(".symtab")

        relocs: list[Reloc] = []
        # Section offset of the most recent R_MIPS_HI16; None until one has been seen.
        offset_hi16: Optional[int] = None
        for reloc in reloc_section.iter_relocations():
            reloc_offset = reloc.entry["r_offset"]
            reloc_type = reloc.entry["r_info_type"]
            reloc_name = symtab.get_symbol(reloc.entry["r_info_sym"]).name

            if reloc_type == 2:  # R_MIPS_32
                addend = int.from_bytes(
                    data[reloc_offset : reloc_offset + 4], "big", signed=True
                )
                relocs.append(Reloc(reloc_name, reloc_offset, None, None, addend))
            elif reloc_type == 4:  # R_MIPS_26
                pass  # jump targets are not tracked
            elif reloc_type == 5:  # R_MIPS_HI16
                offset_hi16 = reloc_offset
            elif reloc_type == 6:  # R_MIPS_LO16
                if offset_hi16 is None:
                    raise FixBssException(
                        f"R_MIPS_LO16 relocation against {reloc_name} at offset "
                        f"0x{reloc_offset:X} in {object_path} ({section_name}) has no "
                        "preceding R_MIPS_HI16"
                    )
                offset_lo16 = reloc_offset
                # The immediates are the low halfword of each instruction; the %lo half
                # is sign-extended when combined with the %hi half.
                addend_hi16 = int.from_bytes(
                    data[offset_hi16 + 2 : offset_hi16 + 4], "big", signed=False
                )
                addend_lo16 = int.from_bytes(
                    data[offset_lo16 + 2 : offset_lo16 + 4], "big", signed=True
                )
                addend = (addend_hi16 << 16) + addend_lo16
                relocs.append(Reloc(reloc_name, None, offset_hi16, offset_lo16, addend))
            else:
                raise NotImplementedError(f"Unsupported relocation type: {reloc_type}")

        return relocs


def get_file_pointers(
    file: mapfile_parser.mapfile.File,
    base: BinaryIO,
    build: BinaryIO,
) -> list[Pointer]:
    """Read, from both ROMs, every relocated pointer in one map-file section.

    Relocations come from ``file.filepath``'s ELF object for section ``file.sectionType``;
    their offsets are translated to ROM offsets by adding ``file.vrom``.

    Raises:
        FixBssException: if a HI16/LO16 instruction pair differs (outside its immediates)
            between the baserom and the build, i.e. the code there is shifted or non-matching.
        AssertionError: if a relocation carries neither a 32-bit nor a HI16/LO16 offset.
    """
    pointers: list[Pointer] = []
    vrom = file.vrom
    # TODO: open each ELF file only once instead of once per section?
    for reloc in read_relocs(file.filepath, file.sectionType):
        if reloc.offset_32 is not None:
            base_value = read_u32(base, vrom + reloc.offset_32)
            build_value = read_u32(build, vrom + reloc.offset_32)
        elif reloc.offset_hi16 is not None and reloc.offset_lo16 is not None:
            hi16_rom = vrom + reloc.offset_hi16
            lo16_rom = vrom + reloc.offset_lo16
            # The first halfword of each instruction (opcode and registers) must match;
            # only the immediates in the second halfword may differ.
            if (read_u16(base, hi16_rom) != read_u16(build, hi16_rom)) or (
                read_u16(base, lo16_rom) != read_u16(build, lo16_rom)
            ):
                raise FixBssException(
                    f"Reference to {reloc.name} in {file.filepath} is in a shifted or non-matching portion of the ROM.\n"
                    "Please ensure that the only differences between the baserom and the current build are due to BSS ordering."
                )

            base_value = (read_u16(base, hi16_rom + 2) << 16) + read_s16(
                base, lo16_rom + 2
            )
            build_value = (read_u16(build, hi16_rom + 2) << 16) + read_s16(
                build, lo16_rom + 2
            )
        else:
            # Explicit raise instead of `assert False`: asserts are stripped under
            # `python -O`, which would leave base_value/build_value unbound.
            raise AssertionError("Invalid relocation")

        # For relocations against a global symbol, subtract the addend so that the pointer
        # is for the start of the symbol. This can help deal with things like STACK_TOP
        # (where the pointer is past the end of the symbol) or negative addends. If the
        # relocation is against a section however, it's not useful to subtract the addend,
        # so we keep it as-is and hope for the best.
        if reloc.name.startswith("."):  # section
            addend = reloc.addend
        else:  # symbol
            addend = 0
            base_value -= reloc.addend
            build_value -= reloc.addend

        pointers.append(Pointer(reloc.name, addend, base_value, build_value))
    return pointers


# Per-worker-process handles for the baserom and the build ROM. They are opened by
# get_file_pointers_worker_init (the pool initializer) and read by get_file_pointers_worker.
base: Optional[BinaryIO] = None
build: Optional[BinaryIO] = None


def get_file_pointers_worker_init(base_path: Path, build_path: Path):
    """Pool initializer: open both ROMs once per worker process.

    The binary handles are stored in the module-level ``base`` and ``build``
    globals for get_file_pointers_worker; they are not closed here.
    """
    global base, build
    base = open(base_path, "rb")
    build = open(build_path, "rb")


def get_file_pointers_worker(file: mapfile_parser.mapfile.File) -> list[Pointer]:
    """Pool task: read one section's pointers using this worker's ROM handles.

    Requires get_file_pointers_worker_init to have run in the current process.
    """
    assert base is not None and build is not None
    return get_file_pointers(file, base, build)


# Map file segment name prefixes for segments built from source code (as opposed to assets).
_SOURCE_CODE_SEGMENT_PREFIXES = ("..boot", "..code", "..n64dd", "..ovl_")


def compare_pointers(version: str) -> dict[Path, BssSection]:
    """Compare pointers between the baserom and the current build.

    Every relocated pointer in the non-.bss sections of the source-code segments is read
    from both ROMs (in parallel worker processes). The pointers are de-duplicated by baserom
    value and then assigned to the .bss section whose build address range contains them.

    Returns a dictionary mapping each C file that has a .bss section in the map file to a
    BssSection holding the section's build start address and the pointers into it.

    Raises:
        FixBssException: if the map file, the baserom or the build ROM does not exist.

    If reading the pointers of any object file raises FixBssException, the error is
    printed and the process exits with status 1.
    """
    mapfile_path = Path(f"build/{version}/oot-{version}.map")
    base_path = Path(f"baseroms/{version}/baserom-decompressed.z64")
    build_path = Path(f"build/{version}/oot-{version}.z64")
    for required_path in (mapfile_path, base_path, build_path):
        if not required_path.exists():
            raise FixBssException(f"Could not open {required_path}")

    mapfile = mapfile_parser.mapfile.MapFile()
    mapfile.readMapFile(mapfile_path)

    # Segments built from source code (filtering out assets)
    source_code_segments = [
        segment
        for segment in mapfile
        if segment.name.startswith(_SOURCE_CODE_SEGMENT_PREFIXES)
    ]

    # Find all pointers with different values
    if not stdout_isatty:
        output("Comparing pointers between baserom and build ...")
    pointers: list[Pointer] = []
    with multiprocessing.Pool(
        initializer=get_file_pointers_worker_init,
        initargs=(base_path, build_path),
    ) as pool:
        pending = [
            pool.apply_async(get_file_pointers_worker, (file,))
            for segment in source_code_segments
            for file in segment
            if str(file.filepath).endswith(".o") and file.sectionType != ".bss"
        ]

        # Report progress and wait until all files are done
        total = len(pending)
        width = len(f"{total}")
        while True:
            time.sleep(0.010)
            done = sum(result.ready() for result in pending)
            if stdout_isatty:
                output(
                    f"Comparing pointers between baserom and build ... {done:>{width}}/{total}",
                    end="\r",
                )
            if done == total:
                break
        if stdout_isatty:
            output("")

        # Collect results and check for errors
        for result in pending:
            try:
                pointers.extend(result.get())
            except FixBssException as e:
                output(f"Error: {str(e)}", color=colorama.Fore.RED)
                sys.exit(1)

    # Remove duplicates (the last pointer seen for a baserom value wins), sort by baserom address
    unique_by_base = {p.base_value: p for p in pointers}
    pointers = sorted(unique_by_base.values(), key=lambda p: p.base_value)

    # Go through .bss sections and collect the pointers into each of them
    bss_sections: dict[Path, BssSection] = {}
    for segment in source_code_segments:
        for file in segment:
            if file.sectionType != ".bss":
                continue

            section_start = file.vram
            section_end = file.vram + file.size
            in_section = [
                p for p in pointers if section_start <= p.build_value < section_end
            ]

            object_file = file.filepath.relative_to(f"build/{version}")
            # Hack to handle the combined z_message_z_game_over.o file.
            # Fortunately z_game_over has no BSS so we can just analyze z_message instead.
            if str(object_file) == "src/code/z_message_z_game_over.o":
                object_file = Path("src/code/z_message.o")

            bss_sections[object_file.with_suffix(".c")] = BssSection(
                file.vram, in_section
            )

    return bss_sections


@dataclass
class Pragma:
    """One ``#pragma increment_block_number`` in a C file, as seen by the compiler."""

    # 1-based line of the pragma in the C file (parsed from the generated symbol names)
    line_number: int
    # Block number of the first symbol generated for this pragma
    block_number: int
    # Number of symbols the pragma generated, i.e. how much it shifts later block numbers
    amount: int


# A BSS variable in the source code
@dataclass
class BssVariable:
    """A BSS variable found in IDO's U-Code output (see find_bss_variables)."""

    # Block number of the variable (its index in the IDO symbol table)
    block_number: int
    # Variable name; function-level statics are prefixed with "<function>::"
    name: str
    # Size taken from the U-Code symbol op
    size: int
    # Required alignment (a power of two)
    align: int
    # Whether a .data/.rodata initializer references this variable
    referenced_in_data: bool


# A BSS variable with its offset in the compiled .bss section
@dataclass
class BssSymbol:
    """A BSS variable placed at a specific offset in a .bss section."""

    # Variable name (same form as BssVariable.name)
    name: str
    # Offset from the start of the .bss section
    offset: int
    # Size of the variable
    size: int
    # Required alignment (a power of two)
    align: int
    # Whether a .data/.rodata initializer references this variable
    referenced_in_data: bool


# Matches the placeholder symbol names generated for #pragma increment_block_number.
# Group 1 is the pragma's line number in the C file (used by find_pragmas); group 2 is
# not used by this script.
INCREMENT_BLOCK_NUMBER_RE = re.compile(r"increment_block_number_(\d+)_(\d+)")


def find_pragmas(symbol_table: list[SymbolTableEntry]) -> list[Pragma]:
    """Find the increment_block_number pragmas of a file from IDO's symbol table.

    The pragmas are recovered by parsing the symbol names generated by preprocess.py. This
    is pretty ugly, but it seems more reliable than trying to determine the line numbers of
    BSS variables in the C file.

    Returns one Pragma per source line, sorted by line number, holding the block number of
    that line's first generated symbol and the number of symbols generated for it.
    """
    first_block_number: dict[int, int] = {}
    amounts: Counter[int] = Counter()
    for block_number, entry in enumerate(symbol_table):
        match = INCREMENT_BLOCK_NUMBER_RE.match(entry.name)
        if match is None:
            continue
        line_number = int(match.group(1))
        first_block_number.setdefault(line_number, block_number)
        amounts[line_number] += 1

    return [
        Pragma(line_number, first_block_number[line_number], amounts[line_number])
        for line_number in sorted(first_block_number)
    ]


# Find all BSS variables from IDO's symbol table and U-Code output.
def find_bss_variables(
    symbol_table: list[SymbolTableEntry], ucode: list[UcodeOp]
) -> list[BssVariable]:
    """Collect the BSS variables declared in a U-Code stream.

    A BSS variable is a ``gsym`` (file-level global), ``lsym`` (file-level static) or
    ``fsym`` (function-level static) op whose block number is never the target of an
    ``init`` op. Function statics are named ``<function>::<name>`` after the most recent
    ``ent`` op. The size is the op's first argument. The alignment is ``1 << lexlev``,
    except that anything of size 8 or more gets an alignment of 8.

    Variables referenced from .data/.rodata (``init`` ops with dtype 10, "Ndt") come first,
    in the order of their first reference. The rest follow in block-number order, which
    mirrors the source order (it doesn't matter, since predict_bss_ordering sorts them
    again anyway).
    """
    initialized = {op.i1 for op in ucode if op.opcode_name == "init"}
    # Block numbers referenced in .data or .rodata (in order of appearance)
    data_refs: list[int] = []
    function_name = None
    variables: list[BssVariable] = []

    for op in ucode:
        opcode = op.opcode_name
        if opcode == "ent":
            function_name = symbol_table[op.i1].name
        elif opcode == "init":
            if op.dtype == 10:  # Ndt, "non-local label"
                assert op.const is not None
                data_refs.append(op.const)
        elif opcode in ("gsym", "lsym", "fsym"):
            block_number = op.i1
            if block_number in initialized:
                continue  # has an initializer, so not BSS

            name = symbol_table[block_number].name
            if opcode == "fsym":
                name = f"{function_name}::{name}"

            size = op.args[0]
            align = 1 << op.lexlev
            # TODO: IDO seems to automatically align anything with size 8 or more to
            # an 8-byte boundary in BSS. Is this correct?
            if size >= 8:
                align = 8

            # NOTE(review): only init ops seen *before* this symbol op are considered here.
            variables.append(
                BssVariable(block_number, name, size, align, block_number in data_refs)
            )

    unreferenced_rank = len(data_refs)

    def order(var: BssVariable) -> tuple[int, int]:
        if var.referenced_in_data:
            rank = data_refs.index(var.block_number)
        else:
            rank = unreferenced_rank
        return (rank, var.block_number)

    variables.sort(key=order)
    return variables


def predict_bss_ordering(variables: list[BssVariable]) -> list[BssSymbol]:
    """Predict the offset of each BSS variable in the compiled .bss section.

    Variables referenced in .data or .rodata keep their given order and come first. The
    others follow, sorted by block number modulo 256 with ties broken by block number.
    Each variable is placed at the next offset that satisfies its alignment.
    """
    in_data = [v for v in variables if v.referenced_in_data]
    by_hash = sorted(
        (v for v in variables if not v.referenced_in_data),
        key=lambda v: (v.block_number % 256, v.block_number),
    )

    placed: list[BssSymbol] = []
    cursor = 0
    for v in in_data + by_hash:
        mask = v.align - 1
        cursor = (cursor + mask) & ~mask
        placed.append(BssSymbol(v.name, cursor, v.size, v.align, v.referenced_in_data))
        cursor += v.size
    return placed


# Match up BSS variables between the baserom and the build using the pointers from relocations.
# Note that we may not be able to match all variables if a variable is not referenced by any pointer.
def determine_base_bss_ordering(
    build_bss_symbols: list[BssSymbol],
    bss_section: BssSection,
) -> list[BssSymbol]:
    """Recover the baserom offsets of the BSS symbols referenced by ``bss_section``'s pointers.

    Each pointer's build value is located inside one of ``build_bss_symbols`` (offsets are
    relative to ``bss_section.start_address``). The same displacement into that symbol is
    then applied to the pointer's baserom value, whose offsets are measured from the lowest
    baserom pointer (assumed to be the start of the baserom .bss section).

    Returns the matched symbols with their baserom offsets, sorted by offset. Symbols not
    referenced by any pointer are absent.

    Raises:
        FixBssException: if there are no pointers, a pointer is outside every build symbol,
            a symbol would start before the section, or two pointers disagree about a
            symbol's baserom offset.
    """
    pointers = bss_section.pointers
    if not pointers:
        # min() below would otherwise fail with an unhelpful ValueError.
        raise FixBssException(
            "No pointers to BSS found; cannot determine the baserom BSS ordering"
        )
    base_start_address = min(p.base_value for p in pointers)

    found_symbols: dict[str, BssSymbol] = {}
    for p in pointers:
        base_offset = p.base_value - base_start_address
        build_offset = p.build_value - bss_section.start_address

        new_symbol = next(
            (
                symbol
                for symbol in build_bss_symbols
                if symbol.offset <= build_offset < symbol.offset + symbol.size
            ),
            None,
        )
        if new_symbol is None:
            if p.addend > 0:
                addend_str = f"+0x{p.addend:X}"
            elif p.addend < 0:
                addend_str = f"-0x{-p.addend:X}"
            else:
                addend_str = ""
            raise FixBssException(
                f"Could not find BSS symbol for pointer {p.name}{addend_str} "
                f"(base address 0x{p.base_value:08X}, build address 0x{p.build_value:08X}). Is the build up-to-date?"
            )

        # Keep the pointer's displacement into the symbol the same in both ROMs.
        new_offset = base_offset - (build_offset - new_symbol.offset)
        if new_offset < 0:
            raise FixBssException(
                f"BSS symbol {new_symbol.name} found at negative offset in the baserom "
                f"(-0x{-new_offset:04X}). Is the build up-to-date?"
            )

        existing = found_symbols.get(new_symbol.name)
        if existing is not None:
            # Sanity check that offsets agree
            if new_offset != existing.offset:
                raise FixBssException(
                    f"BSS symbol {new_symbol.name} found at conflicting offsets in the baserom "
                    f"(0x{existing.offset:04X} and 0x{new_offset:04X}). Is the build up-to-date?"
                )
        else:
            found_symbols[new_symbol.name] = BssSymbol(
                new_symbol.name,
                new_offset,
                new_symbol.size,
                new_symbol.align,
                new_symbol.referenced_in_data,
            )

    return sorted(found_symbols.values(), key=lambda symbol: symbol.offset)


def gen_seq(nu: int) -> Iterator[int]:
    """Yield, in increasing order, the integers in [0, 256) whose 2-adic valuation is exactly ``nu``.

    The 2-adic valuation of an integer n is the largest k such that 2^k divides n
    (see https://en.wikipedia.org/wiki/P-adic_valuation). For convenience, the 2-adic
    valuation of 0 is defined to be 8. The sequences for nu = 8..0 are:

      8: 0
      7: 128
      6: 64, 192
      5: 32, 96, 160, 224
      4: 16, 48, 80, 112, ...
      3: 8, 24, 40, 56, ...
      2: 4, 12, 20, 28, ...
      1: 2, 6, 10, 14, ...
      0: 1, 3, 5, 7, ...
    """
    if nu != 8:
        # Odd multiples of 2^nu: 2^nu, 3 * 2^nu, 5 * 2^nu, ...
        first = 1 << nu
        step = 2 << nu
        for k in range(1 << (7 - nu)):
            yield first + k * step
        return
    yield 0


def gen_candidates_impl(n: int, min_nu: int) -> Iterator[tuple[int, ...]]:
    """Yield every n-tuple of integers in [0, 256) whose minimum 2-adic valuation is exactly ``min_nu``.

    Uses gen_seq's convention that the 2-adic valuation of 0 is 8.

    Raises:
        ValueError: (on first iteration) if ``n`` is less than 1. The recursion below
            would otherwise never reach its ``n == 1`` base case.
    """
    if n < 1:
        raise ValueError(f"n must be at least 1, got {n}")
    if n == 1:
        for value in gen_seq(min_nu):
            yield (value,)
        return

    # (a, *rest) has min 2-adic valuation == min_nu if and only if either:
    #   a has 2-adic valuation >  min_nu and rest has min 2-adic valuation == min_nu, or
    #   a has 2-adic valuation == min_nu and rest has min 2-adic valuation >= min_nu.
    # The two cases are disjoint, so each tuple is yielded exactly once.
    for min_nu_a in reversed(range(min_nu + 1, 9)):
        for a in gen_seq(min_nu_a):
            for rest in gen_candidates_impl(n - 1, min_nu):
                yield (a, *rest)
    for a in gen_seq(min_nu):
        for min_nu_rest in reversed(range(min_nu, 9)):
            for rest in gen_candidates_impl(n - 1, min_nu_rest):
                yield (a, *rest)


def gen_candidates(n: int) -> Iterator[tuple[int, ...]]:
    """Yield every n-tuple of integers in [0, 256), by descending minimum 2-adic valuation.

    Tuples made only of "coarse" values (multiples of large powers of two) come first, so
    a search over the candidates in this order refines each element in a
    binary-search-like fashion. For n = 2 the sequence starts:

      (0, 0), (0, 128), (128, 0), (128, 128), (0, 64), (0, 192), (128, 64), (128, 192), ...
    """
    nu = 8
    while nu >= 0:
        yield from gen_candidates_impl(n, nu)
        nu -= 1


# Determine a new set of increment_block_number pragmas that will fix the BSS ordering.
def solve_bss_ordering(
    pragmas: list[Pragma],
    bss_variables: list[BssVariable],
    base_bss_symbols: list[BssSymbol],
) -> list[Pragma]:
    """Find increment_block_number amounts that reproduce the baserom BSS ordering.

    For each candidate set of amounts, every variable's block number is shifted by the
    change in amount of each pragma at or before it. The resulting layout is predicted with
    predict_bss_ordering, and the first candidate under which every symbol with a known
    baserom offset lands at that offset is returned.

    Returns new Pragma objects (same line and block numbers, new amounts) in the order
    of ``pragmas``.

    Raises:
        FixBssException: if ``pragmas`` is empty or no candidate matches.
    """
    if not pragmas:
        # gen_candidates(0) would recurse without end.
        raise FixBssException("No increment_block_number pragmas to adjust")

    base_offsets = {symbol.name: symbol.offset for symbol in base_bss_symbols}

    # Indices of the pragmas whose generated symbols come at or before each variable. This
    # doesn't depend on the candidate amounts, so compute it once instead of per candidate.
    applicable_pragmas = [
        [i for i, pragma in enumerate(pragmas) if var.block_number >= pragma.block_number]
        for var in bss_variables
    ]

    # Our "algorithm" just tries all possible combinations of increment_block_number amounts,
    # which can get very slow with more than a few pragmas. But, we order the candidates in a
    # binary-search-esque way to try to find a solution faster.
    for new_amounts in gen_candidates(len(pragmas)):
        deltas = [
            new_amount - pragma.amount
            for pragma, new_amount in zip(pragmas, new_amounts)
        ]
        new_bss_variables = [
            BssVariable(
                var.block_number + sum(deltas[i] for i in indices),
                var.name,
                var.size,
                var.align,
                var.referenced_in_data,
            )
            for var, indices in zip(bss_variables, applicable_pragmas)
        ]

        # Predict new BSS and check if the new ordering matches. Symbols without a known
        # baserom offset (not referenced by any pointer) are ignored.
        new_bss_symbols = predict_bss_ordering(new_bss_variables)
        if all(
            base_offsets.get(symbol.name, symbol.offset) == symbol.offset
            for symbol in new_bss_symbols
        ):
            return [
                Pragma(pragma.line_number, pragma.block_number, new_amount)
                for pragma, new_amount in zip(pragmas, new_amounts)
            ]

    raise FixBssException("Could not find any solutions")


def parse_pragma(pragma_string: str) -> dict[str, int]:
    """Parse a ``#pragma increment_block_number`` line into a ``{version: amount}`` dict.

    Line continuations must already have been removed. Double quotes are ignored, and
    every whitespace-separated word after the first two must be a ``version:amount``
    pair with an integer amount. Later pairs for the same version override earlier ones.

    Raises:
        FixBssException: if a word is not a ``version:amount`` pair or the amount is not
            an integer.
    """
    amounts = {}
    for part in pragma_string.replace('"', "").split()[2:]:
        kv = part.split(":")
        if len(kv) != 2:
            raise FixBssException(
                "#pragma increment_block_number"
                f' arguments must be version:amount pairs, not "{part}"'
            )
        try:
            amount = int(kv[1])
        except ValueError:
            # The message already says what is wrong; don't chain the internal ValueError.
            raise FixBssException(
                "#pragma increment_block_number"
                f' amount must be an integer, not "{kv[1]}" (in "{part}")'
            ) from None
        amounts[kv[0]] = amount
    return amounts


def format_pragma(amounts: dict[str, int], max_line_length: int) -> list[str]:
    """Format ``#pragma increment_block_number`` as a list of source lines.

    The ``version:amount`` pairs are written in sorted version order inside a string
    literal. When the next pair would push a line past ``max_line_length``, the literal
    is closed and continued on the next line, indented under the first literal. The
    continuation backslashes are padded so they line up vertically. Every returned line
    ends with a newline.

    A line is never wrapped before it holds at least one pair. Otherwise, a pair too long
    for the limit would produce an empty ``""`` literal line and still not fit.
    """
    pragma_start = "#pragma increment_block_number "
    continuation_start = " " * len(pragma_start) + '"'

    lines: list[str] = []
    current_line = pragma_start + '"'
    line_is_empty = True  # no pair on current_line yet
    for version, amount in sorted(amounts.items()):
        part = f"{version}:{amount}"
        # Reserve room for the separating space and the '" \' that closes a wrapped line.
        if (
            not line_is_empty
            and len(current_line) + len(" ") + len(part) + len('" \\') > max_line_length
        ):
            lines.append(current_line + '" ')
            current_line = continuation_start
            line_is_empty = True
        if not line_is_empty:
            current_line += " "
        current_line += part
        line_is_empty = False
    lines.append(current_line + '"\n')

    if len(lines) >= 2:
        # add and align vertically all continuation \ characters
        n_align = max(map(len, lines[:-1]))
        lines[:-1] = [f"{line:{n_align}}\\\n" for line in lines[:-1]]

    return lines


def update_source_file(version_to_update: str, file: Path, new_pragmas: list[Pragma]):
    """Rewrite increment_block_number pragmas in ``file`` with new amounts for one version.

    Each pragma in ``new_pragmas`` must start on its 1-based ``line_number`` and may be
    continued over several lines with trailing backslashes. Its existing
    ``version:amount`` pairs are kept, and the pair for ``version_to_update`` is set to
    the new amount. The pragma is then re-formatted with format_pragma. The file is
    rewritten in place as UTF-8.

    Raises:
        FixBssException: if a given line is out of range or not a pragma, a continuation
            runs past the end of the file, or a pragma's arguments are malformed.
    """
    column_limit = 120  # matches .clang-format's ColumnLimit

    with open(file, "r", encoding="utf-8") as f:
        lines = f.readlines()

    # (start, end, replacement lines) with 0-based, end-exclusive line indices
    replace_lines: list[tuple[int, int, list[str]]] = []

    for pragma in new_pragmas:
        start = pragma.line_number - 1
        # Check bounds explicitly: line 0 would otherwise silently index the last line,
        # and a line past the end would raise a bare IndexError.
        if not (0 <= start < len(lines)) or not lines[start].startswith(
            "#pragma increment_block_number"
        ):
            raise FixBssException(
                f"Expected #pragma increment_block_number on line {pragma.line_number}"
            )

        # the pragma line and any continuation lines
        end = start + 1
        while lines[end - 1].endswith("\\\n"):
            if end == len(lines):
                raise FixBssException(
                    f"#pragma increment_block_number on line {pragma.line_number} "
                    "continues past the end of the file"
                )
            end += 1

        # concatenate all lines into one
        pragma_string = "".join(line.replace("\\\n", "") for line in lines[start:end])

        amounts = parse_pragma(pragma_string)
        amounts[version_to_update] = pragma.amount

        replace_lines.append((start, end, format_pragma(amounts, column_limit)))

    # Replace the pragma lines starting from the end of the file, so the line numbers
    # for pragmas earlier in the file stay accurate.
    replace_lines.sort(key=lambda it: it[0], reverse=True)
    for start, end, new_pragma_lines in replace_lines:
        lines[start:end] = new_pragma_lines

    with open(file, "w", encoding="utf-8") as f:
        f.writelines(lines)


def process_file(
    file: Path,
    bss_section: BssSection,
    make_log: list[str],
    dry_run: bool,
    version: str,
):
    """Fix the BSS ordering of one C file by adjusting its increment_block_number pragmas.

    The steps are:

    1. Run the compiler front end (run_cfe) with the command line for ``file`` found in
       ``make_log``.
    2. Predict the build's BSS layout.
    3. Recover the baserom layout from ``bss_section``'s pointers.
    4. Search for pragma amounts that reproduce the baserom layout.
    5. Unless ``dry_run``, write those amounts for ``version`` back into the file.

    Progress is printed with output().

    Raises:
        FixBssException: if a step cannot be completed (no compiler command line, no
            pointers into BSS, no or too many pragmas, no solution, ...).
    """
    max_pragmas = 3

    def print_symbols(title: str, symbols: list[BssSymbol]) -> None:
        output(title)
        for sym in symbols:
            output(
                f"  offset=0x{sym.offset:04X} size=0x{sym.size:04X} align=0x{sym.align:X} referenced_in_data={str(sym.referenced_in_data):<5} {sym.name}"
            )

    output(f"Processing {file} ...", color=colorama.Fore.CYAN)

    command_line = find_compiler_command_line(make_log, file)
    if command_line is None:
        raise FixBssException(f"Could not determine compiler command line for {file}")
    output(f"Compiler command: {shlex.join(command_line)}")

    symbol_table, ucode = run_cfe(command_line, keep_files=False)

    bss_variables = find_bss_variables(symbol_table, ucode)
    output("BSS variables:")
    for var in bss_variables:
        bn = var.block_number
        output(
            f"  {bn:>6} [{bn % 256:>3}]: size=0x{var.size:04X} align=0x{var.align:X} referenced_in_data={str(var.referenced_in_data):<5} {var.name}"
        )

    build_bss_symbols = predict_bss_ordering(bss_variables)
    print_symbols("Current build BSS ordering:", build_bss_symbols)

    if not bss_section.pointers:
        raise FixBssException(f"No pointers to BSS found in ROM for {file}")

    base_bss_symbols = determine_base_bss_ordering(build_bss_symbols, bss_section)
    print_symbols("Baserom BSS ordering:", base_bss_symbols)

    pragmas = find_pragmas(symbol_table)
    if not pragmas:
        raise FixBssException(f"No increment_block_number pragmas found in {file}")
    if len(pragmas) > max_pragmas:
        raise FixBssException(
            f"Too many increment_block_number pragmas found in {file} (found {len(pragmas)}, max {max_pragmas})"
        )

    output("Solving BSS ordering ...")
    new_pragmas = solve_bss_ordering(pragmas, bss_variables, base_bss_symbols)
    output("New increment_block_number amounts:")
    for pragma in new_pragmas:
        output(f"  line {pragma.line_number}: {pragma.amount}")

    if dry_run:
        return
    update_source_file(version, file, new_pragmas)
    output(f"Updated {file}", color=colorama.Fore.GREEN)


def process_file_worker(*x):
    """Pool task wrapper around ``process_file(*x)``.

    Everything process_file prints is collected in a buffer and written to the real
    stdout in one go afterwards, so output from files processed in parallel is not
    interleaved. Errors are reported into the same buffer (FixBssException: message only;
    anything else: message plus traceback) and then re-raised, so the task's AsyncResult
    is marked unsuccessful.
    """
    # Collect output in a buffer to avoid interleaving output when processing multiple files
    old_stdout = sys.stdout
    fake_stdout = io.StringIO()
    try:
        sys.stdout = fake_stdout
        process_file(*x)
    except FixBssException as e:
        # exception with a message for the user
        output(f"Error: {str(e)}", color=colorama.Fore.RED)
        raise
    except Exception as e:
        # "unexpected" exception, also print a trace for devs
        output(f"Error: {str(e)}", color=colorama.Fore.RED)
        traceback.print_exc(file=sys.stdout)
        raise
    finally:
        sys.stdout = old_stdout
        output()
        output(fake_stdout.getvalue(), end="")
        # Flush now. When stdout is not a terminal it is block-buffered, and leaving a
        # `with Pool()` block calls terminate() on the workers, which may kill this
        # process before it flushes on its own and drop this file's output.
        sys.stdout.flush()


def _find_files_with_reordering(bss_sections: dict[Path, BssSection]) -> list[Path]:
    """Return the C files whose BSS pointers are laid out differently in the baserom.

    Pointer offsets are measured from the build's .bss start address in the build, and
    from the lowest pointer value in the baserom.
    """
    files_with_reordering = []
    for file, bss_section in bss_sections.items():
        if not bss_section.pointers:
            continue
        # The following heuristic doesn't work for z_locale, since the first pointer into BSS is not
        # at the start of the section. Fortunately z_locale either has one BSS variable (in GC versions)
        # or none (in N64 versions), so we can just skip it.
        if str(file) == "src/boot/z_locale.c":
            continue
        # For the baserom, assume that the lowest address is the start of the BSS section. This might
        # not be true if the first BSS variable is not referenced, but in practice this doesn't happen
        # (except for z_locale above).
        base_min_address = min(p.base_value for p in bss_section.pointers)
        build_min_address = bss_section.start_address
        if any(
            p.build_value - build_min_address != p.base_value - base_min_address
            for p in bss_section.pointers
        ):
            files_with_reordering.append(file)
    return files_with_reordering


def main():
    """Command-line entry point: detect BSS reordering and fix the affected C files.

    Exits with status 1 if the inputs cannot be read or if any file fails to process.
    """
    parser = argparse.ArgumentParser(
        description="Automatically fix BSS ordering by editing increment_block_number pragmas. "
        "Assumes that the build is up-to-date and that only differences between the baserom and "
        "the current build are due to BSS ordering."
    )
    parser.add_argument(
        "-v",
        "--version",
        dest="oot_version",
        type=str,
        required=True,
        help="OOT version",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print changes instead of editing source files",
    )
    parser.add_argument(
        "files",
        metavar="FILE",
        nargs="*",
        type=Path,
        help="Fix BSS ordering for a particular C file (default: all files with BSS differences)",
    )

    args = parser.parse_args()
    version = args.oot_version

    try:
        bss_sections = compare_pointers(version)
    except FixBssException as e:
        # e.g. a missing map file or ROM: report it like the other user-facing errors
        # instead of letting it escape as a traceback.
        output(f"Error: {str(e)}", color=colorama.Fore.RED)
        sys.exit(1)

    files_with_reordering = _find_files_with_reordering(bss_sections)
    if files_with_reordering:
        output("Files with BSS reordering:")
        for file in files_with_reordering:
            output(f"  {file}")
    else:
        output("No BSS reordering found.")

    if args.files:
        # Ignore files that don't have a BSS section in the ROM, but say so.
        files_to_fix = []
        for file in args.files:
            if file in bss_sections:
                files_to_fix.append(file)
            else:
                output(
                    f"Skipping {file}: no .bss section found for it in the map file",
                    color=colorama.Fore.YELLOW,
                )
    else:
        files_to_fix = files_with_reordering
    if not files_to_fix:
        return

    output("Running make to find compiler command line ...")
    make_log = generate_make_log(version)

    with multiprocessing.Pool() as p:
        file_results = [
            p.apply_async(
                process_file_worker,
                (
                    file,
                    bss_sections[file],
                    make_log,
                    args.dry_run,
                    version,
                ),
            )
            for file in files_to_fix
        ]

        # Wait until all files are done
        for file_result in file_results:
            file_result.wait()

        # Collect results and check for errors
        num_successes = sum(file_result.successful() for file_result in file_results)
        all_succeeded = num_successes == len(file_results)
        output()
        output(
            f"Processed {num_successes}/{len(file_results)} files.",
            color=None if all_succeeded else colorama.Fore.RED,
        )
        if not all_succeeded:
            sys.exit(1)


# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()