"""
Author: clueless (discord: object_object69)

Copyright (C) 2025

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.
"""

# Module-wide runtime switches; main() turns them on from command-line flags.
DEBUG = False  # -d / -debug: print counts and pointers while patching
SKIP_VTABLE_CHECK = False  # -skip-vtable-check: skip vtable validation, just step over the dword

import sys
import zlib

from enum import StrEnum
from io import BytesIO
from pathlib import Path
from struct import unpack, pack
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing_extensions import Buffer


class Endianness(StrEnum):
    NATIVE = "="
    LITTLE = "<"
    BIG = ">"


def decode_pointer(ptr: int) -> tuple[int, int]:
    """Split a 32-bit resource pointer into ``(offset, section)``.

    The offset occupies the low 28 bits, so the section tag lives in the top
    4 bits. ``section`` is returned as the top byte with its low nibble
    cleared (``0x50``, ``0x60``, ...), which is the form
    ``BinaryReader.follow_pointer`` compares against.
    """
    offset = ptr & 0x0FFFFFFF
    # Bits 24-27 belong to the 28-bit offset, not the section. Without masking
    # them off, any offset >= 0x1000000 decoded to a section such as 0x51.
    section = (ptr >> 24) & 0xF0
    return offset, section


class BinaryReader(BytesIO):
    """In-memory byte stream with typed read helpers and RSC pointer following.

    Every ``read_*`` method accepts an optional ``endian`` override. When it is
    ``None``, the reader's own ``endian`` attribute is used.

    Attributes:
        rsc_header: Object exposing ``virtual_size``; only needed to follow
            pointers into the ``0x60`` section (see ``follow_pointer``).
        endian: Default byte order for all reads.
    """

    def __init__(
        self,
        initial_bytes: "Buffer" = b"",
        rsc_header=None,
        endian: Endianness = Endianness.NATIVE,
    ):
        # The default used to be ``...`` (copied from a type stub), which made
        # ``BinaryReader()`` raise TypeError inside ``BytesIO.__init__``.
        super().__init__(initial_bytes=initial_bytes)
        self.rsc_header = rsc_header
        self.endian = endian

    def _unpack(self, fmt: str, size: int, endian: Endianness | None) -> tuple:
        """Read ``size`` bytes and unpack them as ``fmt`` in the chosen byte order.

        Raises:
            struct.error: if fewer than ``size`` bytes remain in the stream.
        """
        if endian is None:
            endian = self.endian
        return unpack(f"{endian}{fmt}", self.read(size))

    def read_bool(self, endian: Endianness | None = None) -> bool:
        """Read one byte as a boolean."""
        return self._unpack("?", 1, endian)[0]

    def read_bytes(self, n: int, endian: Endianness | None = None) -> tuple[int, ...]:
        """Read ``n`` signed 8-bit integers."""
        return self._unpack(f"{n}b", n, endian)

    def read_unsigned_bytes(
        self, n: int, endian: Endianness | None = None
    ) -> tuple[int, ...]:
        """Read ``n`` unsigned 8-bit integers."""
        return self._unpack(f"{n}B", n, endian)

    def read_char(self, endian: Endianness | None = None) -> bytes:
        """Read a single byte as a length-1 ``bytes`` object."""
        return self._unpack("c", 1, endian)[0]

    def read_chars(self, n: int, endian: Endianness | None = None) -> tuple[bytes, ...]:
        """Read ``n`` bytes as a tuple of length-1 ``bytes`` objects."""
        return self._unpack(f"{n}c", n, endian)

    def read_char_array(self, n: int, endian: Endianness | None = None) -> bytes:
        """Read ``n`` bytes as one ``bytes`` object."""
        return self._unpack(f"{n}s", n, endian)[0]

    def read_float(self, endian: Endianness | None = None) -> float:
        """Read a 32-bit float."""
        return self._unpack("f", 4, endian)[0]

    def read_floats(
        self, n: int, endian: Endianness | None = None
    ) -> tuple[float, ...]:
        """Read ``n`` consecutive 32-bit floats."""
        return self._unpack(f"{n}f", 4 * n, endian)

    def read_int16(self, endian: Endianness | None = None) -> int:
        """Read a signed 16-bit integer."""
        return self._unpack("h", 2, endian)[0]

    def read_int32(self, endian: Endianness | None = None) -> int:
        """Read a signed 32-bit integer."""
        return self._unpack("i", 4, endian)[0]

    def read_uint16(self, endian: Endianness | None = None) -> int:
        """Read an unsigned 16-bit integer."""
        return self._unpack("H", 2, endian)[0]

    def read_uint32(self, endian: Endianness | None = None) -> int:
        """Read an unsigned 32-bit integer."""
        return self._unpack("I", 4, endian)[0]

    def follow_pointer(self, pointer: int):
        """Seek to the stream position a resource pointer refers to.

        Section ``0x50`` offsets are used as absolute positions. Section
        ``0x60`` offsets are shifted by ``rsc_header.virtual_size``.

        Raises:
            RuntimeError: for a ``0x60`` pointer when ``rsc_header`` is unset
                (a subclass of ``Exception``, which was raised before).
            ValueError: for any other section.
        """
        offset, section = decode_pointer(pointer)
        if section == 0x50:
            self.seek(offset)
        elif section == 0x60:
            if self.rsc_header is None:
                raise RuntimeError("rsc_header is not set")
            self.seek(offset + self.rsc_header.virtual_size)
        else:
            raise ValueError(f"Unknown pointer section {hex(section)}")


def unlock_model_collection(reader: BinaryReader) -> None:
    """Rewrite the face count of every geometry in a model collection.

    ``reader`` must be positioned at the collection. It starts with a pointer
    to an array of model pointers, followed by a ``uint16`` model count. Each
    model, 4 bytes in, has a pointer to its geometry-pointer array plus a
    ``uint16`` geometry count. For each geometry, the ``uint32`` face count
    stored right after the index count is overwritten in place with
    ``index_count // 3``.
    """
    model_ptr_arr_ptr = reader.read_uint32()
    num_models = reader.read_uint16()
    if DEBUG:
        print(f"{num_models=}")
    reader.follow_pointer(model_ptr_arr_ptr)
    model_arr = [reader.read_uint32() for _ in range(num_models)]

    for model_ptr in model_arr:
        reader.follow_pointer(model_ptr)
        reader.seek(4, 1)  # skip the model's leading dword (presumably its vtable; unchecked)
        geometry_ptr_arr_ptr = reader.read_uint32()
        num_geometries = reader.read_uint16()
        if DEBUG:
            print(f"{num_geometries=}")
        reader.follow_pointer(geometry_ptr_arr_ptr)
        geometry_arr = [reader.read_uint32() for _ in range(num_geometries)]

        for geometry_ptr in geometry_arr:
            reader.follow_pointer(geometry_ptr)
            # Index count sits 44 bytes in, followed by the face count and a
            # uint16 vertex count (the last two are only read for debug output).
            reader.seek(44, 1)
            index_count = reader.read_uint32()
            face_offset = reader.tell()
            face_count = reader.read_uint32()
            vertex_count = reader.read_uint16()
            if DEBUG:
                print(f"{index_count=}, {face_count=}, {vertex_count=}")
                print(f"{face_count} => {index_count // 3}")
            reader.seek(face_offset)
            # Pack in the reader's byte order so the patched value matches how
            # every field was read; bare "I" would use the host's native order.
            reader.write(pack(f"{reader.endian}I", index_count // 3))


def unlock_drawable(reader: BinaryReader) -> None:
    """Unlock every LOD model collection of the drawable at the reader's position.

    The drawable starts with a vtable dword. It is checked against known
    values unless ``SKIP_VTABLE_CHECK`` is set. 64 bytes in are four pointers
    to the High, Med, Low and VLow model collections; null pointers are
    skipped. On an unrecognised vtable a message is printed and the drawable
    is left untouched. Nothing is reported back to the caller.
    """
    if SKIP_VTABLE_CHECK:
        reader.seek(4, 1)
    else:
        vtable = reader.read_uint32()
        if vtable not in (0x695254, 0xB8AAF60D, 0x49C27D95):
            # Previously the raw vtable was printed unconditionally, in decimal
            # and unlabeled; fold it into the message in hex instead.
            print(
                f"Invalid drawable @ {reader.tell() - 4} (vtable {vtable:#x}), "
                "skipping..."
            )
            return None

    reader.seek(60, 1)

    # (debug key, display name) for each LOD pointer, in on-disk order.
    lods = (("high", "High"), ("med", "Med"), ("low", "Low"), ("vlow", "VLow"))

    # Read all four pointers first: following one moves the stream position.
    collection_ptrs = []
    for key, _ in lods:
        ptr = reader.read_uint32()
        if DEBUG:
            print(f"{key}_model_collection_ptr={ptr:#x}")
        collection_ptrs.append(ptr)

    for (_, name), ptr in zip(lods, collection_ptrs):
        if not ptr:
            continue
        if DEBUG:
            print(f"Unlocking {name} Model")
        reader.follow_pointer(ptr)
        unlock_model_collection(reader)
    return None


def _open_rsc_file(
    filepath: Path, endian: Endianness = Endianness.NATIVE
) -> tuple[bytes, bytes | None, BinaryReader | None]:
    """Open an RSC5 resource and decompress its zlib payload into a reader.

    Returns ``(magic, header_rest, reader)``. ``header_rest`` is the 8 header
    bytes that follow the magic, kept verbatim so they can be written back
    unchanged. If the magic is not ``b"RSC\\x05"``, a message is printed and
    ``(magic, None, None)`` is returned.
    """
    with open(filepath, "rb") as stream:
        magic = stream.read(4)
        if magic == b"RSC\x05":
            header_rest = stream.read(8)
            payload = zlib.decompress(stream.read())
        else:
            print(f"{filepath.absolute()} is not a RSC resource file, skipping...")
            return magic, None, None
    return magic, header_rest, BinaryReader(payload, endian=endian)

def _write_rsc_data(rsc_magic: bytes, sizes: bytes, reader: BytesIO) -> BytesIO:
    """Rebuild an RSC file in memory: header bytes, then the recompressed payload.

    The whole contents of ``reader`` are compressed at the best zlib level.
    The returned stream is left positioned at its end.
    """
    compressed = zlib.compress(reader.getbuffer(), level=zlib.Z_BEST_COMPRESSION)
    result = BytesIO()
    for chunk in (rsc_magic, sizes, compressed):
        result.write(chunk)
    return result


def unlock_wdr(
    filepath: Path, endian: Endianness = Endianness.LITTLE
) -> BytesIO | None:
    """Unlock a single-drawable (``.wdr``) resource file.

    Returns the recompressed file contents, or ``None`` when the file is not
    an RSC5 resource.

    NOTE(review): ``unlock_drawable`` only prints a message when it rejects
    the vtable, so in that case the unchanged data is still returned here.
    """
    magic, header_rest, drawable_reader = _open_rsc_file(filepath, endian)
    if header_rest is not None and drawable_reader is not None:
        unlock_drawable(drawable_reader)
        return _write_rsc_data(magic, header_rest, drawable_reader)
    return None


def unlock_wdd(
    filepath: Path, endian: Endianness = Endianness.LITTLE
) -> BytesIO | None:
    """Unlock every drawable in a drawable-dictionary (``.wdd``) resource file.

    Returns the recompressed file contents. Returns ``None`` when the file is
    not an RSC5 resource, or when its vtable is not the expected dictionary
    vtable (unless ``SKIP_VTABLE_CHECK`` is set).
    """
    rsc_magic, sizes, reader = _open_rsc_file(filepath, endian)
    if sizes is None or reader is None:
        return None

    if SKIP_VTABLE_CHECK:
        reader.seek(4, 1)
    elif reader.read_uint32() != 0x6953A4:
        print(f"{filepath.absolute()} does not contain WDD, skipping...")
        return None

    # 24 bytes into the dictionary: pointer to the drawable-pointer array,
    # followed by a uint16 entry count.
    reader.seek(20, 1)
    array_ptr = reader.read_uint32()
    count = reader.read_uint16()

    reader.follow_pointer(array_ptr)
    drawable_ptrs = [reader.read_uint32() for _ in range(count)]
    for ptr in drawable_ptrs:
        reader.follow_pointer(ptr)
        unlock_drawable(reader)

    return _write_rsc_data(rsc_magic, sizes, reader)


def unlock_wft(
    filepath: Path, endian: Endianness = Endianness.LITTLE
) -> BytesIO | None:
    """Unlock a fragment (``.wft``) resource file.

    Patches the fragment's main drawable, then the drawable of every entry in
    the fragment-group pointer array. That array is scanned until the
    ``0xCDCDCDCD`` sentinel. Null pointers are skipped, the same way
    ``unlock_drawable`` skips null LOD pointers. Previously they made
    ``follow_pointer`` raise "Unknown pointer section 0x0" and aborted the file.

    Returns:
        The recompressed file contents, or ``None`` if the file is not an RSC5
        resource or its vtable is not a recognised fragment vtable (unless
        ``SKIP_VTABLE_CHECK`` is set).
    """
    rsc_magic, sizes, reader = _open_rsc_file(filepath, endian)
    if sizes is None or reader is None:
        return None

    if not SKIP_VTABLE_CHECK:
        wft_vtable = reader.read_uint32()
        if wft_vtable not in (0x695238, 0x930A4609):
            print(f"{filepath.absolute()} does not contain WFT, skipping...")
            return None
    else:
        reader.seek(4, 1)

    # Main drawable pointer at offset 0xB4; group-array pointer 32 bytes later.
    reader.seek(176, 1)
    main_frag_drawable_ptr = reader.read_uint32()
    reader.seek(28, 1)
    frag_group_ptr_arr_ptr = reader.read_uint32()

    if main_frag_drawable_ptr:
        reader.follow_pointer(main_frag_drawable_ptr)
        unlock_drawable(reader)

    frag_group_ptr_arr = []
    if frag_group_ptr_arr_ptr:
        reader.follow_pointer(frag_group_ptr_arr_ptr)
        # NOTE(review): relies on the 0xCDCDCDCD padding as a terminator; a
        # missing sentinel reads to EOF and raises struct.error.
        while (val := reader.read_uint32()) != 0xCDCDCDCD:
            frag_group_ptr_arr.append(val)

    if DEBUG:
        print(f"{len(frag_group_ptr_arr)=}")

    for frag_group_ptr in frag_group_ptr_arr:
        if not frag_group_ptr:
            continue
        reader.follow_pointer(frag_group_ptr)
        reader.seek(144, 1)  # drawable pointer sits 0x90 bytes into the entry
        drawable_ptr = reader.read_uint32()
        if not drawable_ptr:
            continue
        reader.follow_pointer(drawable_ptr)
        unlock_drawable(reader)

    return _write_rsc_data(rsc_magic, sizes, reader)


def main() -> None:
    """Interactive entry point: unlock every .wdr/.wdd/.wft file in a folder.

    Flags read from ``sys.argv``:
        -d / -debug          enable debug output
        -skip-vtable-check   do not validate vtables before patching

    Prompts for an input and an output folder. Each successfully processed
    file is written under the same name into the output folder. A file that
    fails to parse is reported and skipped, so the rest of the batch still
    runs (previously the first exception aborted everything).
    """
    global SKIP_VTABLE_CHECK, DEBUG
    import traceback

    print("USE AT YOUR OWN RISK!\n")

    args = sys.argv[1:]
    if "-skip-vtable-check" in args:
        SKIP_VTABLE_CHECK = True
    if "-d" in args or "-debug" in args:
        DEBUG = True

    print(f"{DEBUG = }", end="")
    if not DEBUG:
        print(" | Use -d or -debug flag to enable debug mode", end="")
    print()
    print(f"{SKIP_VTABLE_CHECK = }", end="")
    if not SKIP_VTABLE_CHECK:
        print(" | Use -skip-vtable-check flag to skip vtable checks", end="")
    print()

    input_folder = Path(input("Input Folder: "))
    output_folder = Path(input("Output Folder: "))
    if not input_folder.is_dir():
        # glob() on a missing folder just yields nothing, which looked like success.
        print(f"\033[31m{input_folder.absolute()} is not a folder\033[0m")
        return
    # parents=True so a nested output path that doesn't exist yet still works.
    output_folder.mkdir(parents=True, exist_ok=True)

    unlockers = {".wdr": unlock_wdr, ".wdd": unlock_wdd, ".wft": unlock_wft}
    for file in input_folder.glob("*.w[df][rdt]"):
        # On Windows, glob matches case-insensitively, so also compare the
        # suffix case-insensitively or e.g. "X.WDR" would be silently skipped.
        unlocker = unlockers.get(file.suffix.lower())
        if unlocker is None or not file.is_file():
            continue
        print(f"\033[33mProcessing {file}\033[0m")
        output_filepath = output_folder / file.name
        try:
            output = unlocker(file)
            if output is None:
                continue
            output_filepath.write_bytes(output.getbuffer())
        except Exception as exc:  # batch boundary: report, then keep going
            print(f"\033[31mFailed to process {file}: {exc!r}\033[0m")
            if DEBUG:
                traceback.print_exc()
            continue
        print(f"\033[32mSuccessfully written to {output_filepath.absolute()}\033[0m")


if __name__ == "__main__":
    main()
