# credits: https://gtamods.com/wiki/IMG_archive, https://github.com/ahmed605/SparkIV/tree/master/SRC/RageLib/FileSystem/IMG

import struct
from dataclasses import dataclass, field
from io import BytesIO
from pathlib import Path

from cryptography import decrypt_aes


@dataclass(init=False)
class Header:
    """Fixed 20-byte header found at the start of an IMG v3 archive.

    Layout (little-endian): four uint32 values (identifier, version,
    num_items, table_size) followed by two uint16 values (entry_size, unk).
    Instances are created empty and filled in by :meth:`read`.
    """

    identifier: int
    version: int
    num_items: int
    table_size: int
    entry_size: int
    unk: int

    def read(self, data: bytes):
        """Decode *data* into this header's fields and return ``self``.

        *data* must be exactly 20 bytes; otherwise ``struct.error`` is raised.
        Returning ``self`` allows the ``Header().read(raw)`` construction idiom.
        """
        fields_in_order = struct.unpack("<4I2H", data)
        identifier, version, num_items, table_size, entry_size, unk = fields_in_order
        self.identifier = identifier
        self.version = version
        self.num_items = num_items
        self.table_size = table_size
        self.entry_size = entry_size
        self.unk = unk
        return self


@dataclass(init=False)
class Item:
    """A single 16-byte entry of the IMG v3 item table.

    ``name`` starts out empty; the name-table parser assigns it afterwards.
    """

    name: str
    size: int
    resource_type: int
    block_offset: int
    num_blocks_used: int
    # NOTE(review): each entry ends with a further int16 that is decoded but
    # discarded; whether it is padding or flags is unconfirmed.
    # TODO: resource_type is kept as the raw integer and is not validated.

    def __init__(self, data: bytes):
        """Decode one table entry from exactly 16 bytes of *data*.

        Layout (little-endian): uint32 size, int32 resource_type,
        int32 block_offset, int16 num_blocks_used, int16 (ignored).
        Raises ``struct.error`` if *data* is not 16 bytes long.
        """
        unpacked = struct.unpack("<I2i2h", data)
        size, resource_type, block_offset, num_blocks_used, _trailer = unpacked
        self.name = ""
        self.size = size
        self.resource_type = resource_type
        self.block_offset = block_offset
        self.num_blocks_used = num_blocks_used


@dataclass(init=False)
class Table:
    """Item table of an IMG v3 archive: fixed-size entries followed by names.

    The (already decrypted) table data holds ``header.num_items`` entries
    spaced ``header.entry_size`` bytes apart, followed by a block of
    NUL-terminated names assigned to the entries in order.
    """

    items: list[Item] = field(default_factory=list)

    def __init__(self, data: bytes, header: Header):
        """Parse the entries and their names from *data*.

        Args:
            data: raw table bytes (decrypted if the archive is encrypted).
            header: parsed archive header supplying the item count, entry
                size and total table size.

        Raises:
            ValueError: if any entry is left without a name.
        """
        entry_size = header.entry_size
        # Use the header's entry size for both the entry stride and the
        # name-table offset so the two computations always agree. Each Item
        # itself decodes exactly 16 bytes.
        self.items: list[Item] = [
            Item(data[i * entry_size : i * entry_size + 16])
            for i in range(header.num_items)
        ]

        entries_end = header.num_items * entry_size
        nametable_size = header.table_size - entries_end
        nametable_data = data[entries_end : entries_end + nametable_size]

        # Only NUL-terminated chunks count as names, so the final chunk after
        # the last NUL is dropped. zip() stops at the last entry, so trailing
        # padding (extra NULs) cannot index past self.items.
        names = nametable_data.split(b"\x00")[:-1]
        for item, name in zip(self.items, names):
            item.name = name.decode()

        if any(not item.name for item in self.items):
            raise ValueError("Name not found")


@dataclass(init=False)
class IMG3Archive:
    """Reader for version-3 IMG archives, optionally AES-encrypted.

    The header and item table are parsed when the archive is constructed;
    item payloads are read from ``filepath`` on demand by :meth:`extract`
    and :meth:`open`.
    """

    filepath: Path
    encrypted: bool
    header: Header
    table: Table

    # Identifier of a plain-text v3 archive (2840472146 == 0xA94E2A52).
    _IDENTIFIER = 0xA94E2A52
    # Only this archive version is supported.
    _VERSION = 3
    # Byte length of the header, matching Header's "<4I2H" layout.
    _HEADER_SIZE = 20
    # Item offsets and lengths are stored in units of this many bytes.
    _BLOCK_SIZE = 2048

    def __init__(self, filepath: str | Path, aes_key: str | bytes | None = None):
        """Open *filepath* and parse its header and item table.

        Args:
            filepath: path to the archive file.
            aes_key: key passed to ``decrypt_aes`` when the archive does not
                start with the plain-text identifier. May be omitted for
                unencrypted archives.

        Raises:
            ValueError: if the file is too short to hold a header, is
                encrypted but no key was supplied, or does not decode to a
                version-3 header (e.g. wrong key or not an IMG archive).
        """
        self.filepath = Path(filepath)
        with self.filepath.open("rb") as f:
            raw_header = f.read(self._HEADER_SIZE)
            if len(raw_header) < self._HEADER_SIZE:
                raise ValueError(f"{self.filepath} is too short to be an IMG archive")

            # A plain-text archive begins with the known identifier; anything
            # else is assumed to be encrypted.
            (identifier,) = struct.unpack_from("<I", raw_header)
            self.encrypted = identifier != self._IDENTIFIER
            if self.encrypted and aes_key is None:
                raise ValueError(f"{self.filepath} is encrypted but no aes_key was given")

            header_data = decrypt_aes(raw_header, aes_key) if self.encrypted else raw_header
            self.header = Header().read(header_data)

            # Explicit checks instead of assert, which is stripped under -O.
            if self.header.identifier != self._IDENTIFIER:
                raise ValueError(
                    f"{self.filepath}: bad identifier {self.header.identifier:#x} "
                    "(wrong aes_key or not an IMG archive)"
                )
            if self.header.version != self._VERSION:
                raise ValueError(
                    f"{self.filepath}: unsupported IMG version {self.header.version}"
                )

            table_data = f.read(self.header.table_size)
            if self.encrypted:
                table_data = decrypt_aes(table_data, aes_key)

            self.table = Table(table_data, self.header)

    def search(self, item_name: str) -> list[Item]:
        """Return every item whose name contains *item_name* as a substring."""
        return [item for item in self.table.items if item_name in item.name]

    def get_item(self, item_name: str) -> Item | None:
        """Return the first item named exactly *item_name*, or None."""
        return next((item for item in self.table.items if item.name == item_name), None)

    def __contains__(self, item_name: str) -> bool:
        """Report whether an item named exactly *item_name* exists.

        Uses the same exact-name lookup as :meth:`extract` and :meth:`open`,
        so ``name in archive`` implies those calls will find the item.
        """
        return self.get_item(item_name) is not None

    def _read_item_bytes(self, item: Item) -> bytes:
        """Read *item*'s payload from the archive as whole blocks.

        Because whole blocks are read, the result may be longer than the
        item's actual content.
        """
        with self.filepath.open("rb") as reader:
            reader.seek(item.block_offset * self._BLOCK_SIZE)
            return reader.read(item.num_blocks_used * self._BLOCK_SIZE)

    def extract(self, item_name: str, output_filepath: Path | str) -> bool:
        """Write the item named *item_name* to *output_filepath*.

        Returns:
            True if the item was found and written, False if no item has
            that exact name (in which case nothing is written).
        """
        item = self.get_item(item_name)
        if item is None:
            return False

        # Read before opening the destination, so a failed read does not
        # leave an empty output file behind.
        data = self._read_item_bytes(item)
        Path(output_filepath).write_bytes(data)
        return True

    def open(self, item_name: str) -> BytesIO | None:
        """Return the item named *item_name* as an in-memory buffer, or None."""
        item = self.get_item(item_name)
        if item is None:
            return None
        return BytesIO(self._read_item_bytes(item))
