import mch22
import buttons
import display
import os
import sys
import struct
import sndmixer

class WADFile:
    """Minimal read-only parser for Doom IWAD/PWAD archives.

    Only the 12-byte header and the lump directory are read at construction
    time; lump contents are read from disk on demand.
    """

    def __init__(self, path):
        self.path = path
        self.ident = None          # "IWAD" or "PWAD" once parsed
        self.numlumps = 0          # lump count from the header
        self.infotableofs = 0      # file offset of the lump directory
        # lump name -> (filepos, size); when a name occurs more than once the
        # later directory entry replaces the earlier one.
        self.directory = {}
        self._parse()

    def _parse(self):
        """Read the header and the lump directory into ``self.directory``.

        Raises:
            ValueError: if the file does not start with "IWAD" or "PWAD".
        """
        with open(self.path, "rb") as f:
            # --- Header: 4-byte ident, uint32 lump count, uint32 dir offset ---
            self.ident = f.read(4).decode("ascii", "ignore")
            if self.ident not in ("IWAD", "PWAD"):
                raise ValueError("Not a WAD file")

            self.numlumps = int.from_bytes(f.read(4), "little")
            self.infotableofs = int.from_bytes(f.read(4), "little")

            # --- Directory entries: numlumps * 16 bytes ---
            f.seek(self.infotableofs)
            for _ in range(self.numlumps):
                entry = f.read(16)
                if len(entry) < 16:
                    # Truncated directory: keep the complete entries read so
                    # far instead of inventing zero-filled ones.
                    break
                filepos = int.from_bytes(entry[0:4], "little")
                size = int.from_bytes(entry[4:8], "little")
                # The 8-byte name is NUL-padded; bytes after the first NUL are
                # not part of the name, so cut there rather than only
                # stripping trailing NULs.
                raw_name = entry[8:16].split(b"\0", 1)[0].rstrip(b" ")
                self.directory[raw_name.decode("ascii", "ignore")] = (filepos, size)

    def extract_sound(self, name):
        """Return ``(samplerate, pcm)`` for the sound lump *name*.

        *name* is matched case-insensitively (it is upper-cased first).

        DMX digitized sounds (format number 3) are decoded from their header:
        uint16 format, uint16 sample rate, uint32 sample count, followed by the
        samples. The first and last 16 of those samples are padding and are
        dropped. Lumps of any other format keep the legacy heuristic (uint16
        count at offset 2, samples from offset 4); NOTE(review): such lumps
        are presumably not 8-bit PCM.

        An empty or too-short lump yields ``(0, b"")``.

        Raises:
            KeyError: if the WAD has no lump with that name.
        """
        name = name.upper()

        entry = self.directory.get(name)
        if entry is None:
            raise KeyError("Sound lump not found")

        filepos, size = entry
        with open(self.path, "rb") as f:
            f.seek(filepos)
            data = f.read(size)

        if len(data) < 4:
            # Nothing decodable: report silence rather than crash on data[0].
            return 0, b""

        if int.from_bytes(data[0:2], "little") == 3 and len(data) >= 8:
            # DMX sound: the rate and sample count come from the header.
            samplerate = int.from_bytes(data[2:4], "little")
            samplecount = int.from_bytes(data[4:8], "little")
            pcm = data[8:8 + samplecount]
            # Strip the 16 leading / 16 trailing pad samples, but only when
            # the lump actually holds the declared amount of data.
            if samplecount > 32 and len(pcm) == samplecount:
                pcm = pcm[16:-16]
            return samplerate, pcm

        # Legacy heuristic for non-DMX lumps (behavior kept from before).
        lead = data[0]
        if lead in (0, 1, 2, 3):
            samplerate = 11025 if lead == 3 else 8000
        else:
            samplerate = int.from_bytes(data[0:2], "little")
        samplecount = int.from_bytes(data[2:4], "little")
        return samplerate, data[4:4 + samplecount]

# Sound-effect definitions. The list index is the sound id: load_sounds()
# builds its result in this order and play_sound_by_index() looks sounds up
# by that same index, so entries must not be reordered, inserted or removed.
# NOTE(review): the order presumably has to match the sound ids the
# FPGA-side engine sends — confirm against the engine before editing.
#
# Only field 0 (the lump name without its "DS" prefix) is read in this file.
# The remaining fields appear to mirror Doom's S_sfx table in sounds.c
# (presumably singularity, priority, link, pitch, volume, data) — unused
# here; verify against the engine source before relying on them.
sound_table = [
    ("none",    False,  0,   0, -1, -1, 0),
    ("pistol",  False, 64,   0, -1, -1, 0),
    ("shotgn",  False, 64,   0, -1, -1, 0),
    ("sgcock",  False, 64,   0, -1, -1, 0),
    ("dshtgn",  False, 64,   0, -1, -1, 0),
    ("dbopn",   False, 64,   0, -1, -1, 0),
    ("dbcls",   False, 64,   0, -1, -1, 0),
    ("dbload",  False, 64,   0, -1, -1, 0),
    ("plasma",  False, 64,   0, -1, -1, 0),
    ("bfg",     False, 64,   0, -1, -1, 0),
    ("sawup",   False, 64,   0, -1, -1, 0),
    ("sawidl",  False,118,   0, -1, -1, 0),
    ("sawful",  False, 64,   0, -1, -1, 0),
    ("sawhit",  False, 64,   0, -1, -1, 0),
    ("rlaunc",  False, 64,   0, -1, -1, 0),
    ("rxplod",  False, 70,   0, -1, -1, 0),
    ("firsht",  False, 70,   0, -1, -1, 0),
    ("firxpl",  False, 70,   0, -1, -1, 0),
    ("pstart",  False,100,   0, -1, -1, 0),
    ("pstop",   False,100,   0, -1, -1, 0),
    ("doropn",  False,100,   0, -1, -1, 0),
    ("dorcls",  False,100,   0, -1, -1, 0),
    ("stnmov",  False,119,   0, -1, -1, 0),
    ("swtchn",  False, 78,   0, -1, -1, 0),
    ("swtchx",  False, 78,   0, -1, -1, 0),
    ("plpain",  False, 96,   0, -1, -1, 0),
    ("dmpain",  False, 96,   0, -1, -1, 0),
    ("popain",  False, 96,   0, -1, -1, 0),
    ("vipain",  False, 96,   0, -1, -1, 0),
    ("mnpain",  False, 96,   0, -1, -1, 0),
    ("pepain",  False, 96,   0, -1, -1, 0),
    ("slop",    False, 78,   0, -1, -1, 0),
    ("itemup",  True,  78,   0, -1, -1, 0),
    ("wpnup",   True,  78,   0, -1, -1, 0),
    ("oof",     False, 96,   0, -1, -1, 0),
    ("telept",  False, 32,   0, -1, -1, 0),
    ("posit1",  True,  98,   0, -1, -1, 0),
    ("posit2",  True,  98,   0, -1, -1, 0),
    ("posit3",  True,  98,   0, -1, -1, 0),
    ("bgsit1",  True,  98,   0, -1, -1, 0),
    ("bgsit2",  True,  98,   0, -1, -1, 0),
    ("sgtsit",  True,  98,   0, -1, -1, 0),
    ("cacsit",  True,  98,   0, -1, -1, 0),
    ("brssit",  True,  94,   0, -1, -1, 0),
    ("cybsit",  True,  92,   0, -1, -1, 0),
    ("spisit",  True,  90,   0, -1, -1, 0),
    ("bspsit",  True,  90,   0, -1, -1, 0),
    ("kntsit",  True,  90,   0, -1, -1, 0),
    ("vilsit",  True,  90,   0, -1, -1, 0),
    ("mansit",  True,  90,   0, -1, -1, 0),
    ("pesit",   True,  90,   0, -1, -1, 0),
    ("sklatk",  False, 70,   0, -1, -1, 0),
    ("sgtatk",  False, 70,   0, -1, -1, 0),
    ("skepch",  False, 70,   0, -1, -1, 0),
    ("vilatk",  False, 70,   0, -1, -1, 0),
    ("claw",    False, 70,   0, -1, -1, 0),
    ("skeswg",  False, 70,   0, -1, -1, 0),
    ("pldeth",  False, 32,   0, -1, -1, 0),
    ("pdiehi",  False, 32,   0, -1, -1, 0),
    ("podth1",  False, 70,   0, -1, -1, 0),
    ("podth2",  False, 70,   0, -1, -1, 0),
    ("podth3",  False, 70,   0, -1, -1, 0),
    ("bgdth1",  False, 70,   0, -1, -1, 0),
    ("bgdth2",  False, 70,   0, -1, -1, 0),
    ("sgtdth",  False, 70,   0, -1, -1, 0),
    ("cacdth",  False, 70,   0, -1, -1, 0),
    ("skldth",  False, 70,   0, -1, -1, 0),
    ("brsdth",  False, 32,   0, -1, -1, 0),
    ("cybdth",  False, 32,   0, -1, -1, 0),
    ("spidth",  False, 32,   0, -1, -1, 0),
    ("bspdth",  False, 32,   0, -1, -1, 0),
    ("vildth",  False, 32,   0, -1, -1, 0),
    ("kntdth",  False, 32,   0, -1, -1, 0),
    ("pedth",   False, 32,   0, -1, -1, 0),
    ("skedth",  False, 32,   0, -1, -1, 0),
    ("posact",  True, 120,   0, -1, -1, 0),
    ("bgact",   True, 120,   0, -1, -1, 0),
    ("dmact",   True, 120,   0, -1, -1, 0),
    ("bspact",  True, 100,   0, -1, -1, 0),
    ("bspwlk",  True, 100,   0, -1, -1, 0),
    ("vilact",  True, 100,   0, -1, -1, 0),
    ("noway",   False, 78,   0, -1, -1, 0),
    ("barexp",  False, 60,   0, -1, -1, 0),
    ("punch",   False, 64,   0, -1, -1, 0),
    ("hoof",    False, 70,   0, -1, -1, 0),
    ("metal",   False, 70,   0, -1, -1, 0),
    ("chgun",   False, 64,   0, -1, -1, 0),
    ("tink",    False, 60,   0, -1, -1, 0),
    ("bdopn",   False,100,   0, -1, -1, 0),
    ("bdcls",   False,100,   0, -1, -1, 0),
    ("itmbk",   False,100,   0, -1, -1, 0),
    ("flame",   False, 32,   0, -1, -1, 0),
    ("flamst",  False, 32,   0, -1, -1, 0),
    ("getpow",  False, 60,   0, -1, -1, 0),
    ("bospit",  False, 70,   0, -1, -1, 0),
    ("boscub",  False, 70,   0, -1, -1, 0),
    ("bossit",  False, 70,   0, -1, -1, 0),
    ("bospn",   False, 70,   0, -1, -1, 0),
    ("bosdth",  False, 70,   0, -1, -1, 0),
    ("manatk",  False, 70,   0, -1, -1, 0),
    ("mandth",  False, 70,   0, -1, -1, 0),
    ("sssit",   False, 70,   0, -1, -1, 0),
    ("ssdth",   False, 70,   0, -1, -1, 0),
    ("keenpn",  False, 70,   0, -1, -1, 0),
    ("keendt",  False, 70,   0, -1, -1, 0),
    ("skeact",  False, 70,   0, -1, -1, 0),
    ("skesit",  False, 70,   0, -1, -1, 0),
    ("skeatk",  False, 70,   0, -1, -1, 0),
    ("radio",   False, 60,   0, -1, -1, 0),
]

def raw_u8_mono_to_wav(raw: bytearray, sample_rate: int) -> bytes:
    """
    Wrap unsigned 8-bit mono PCM samples in a minimal RIFF/WAVE container.

    The header is assembled by hand with struct (no wave module): a 16-byte
    PCM "fmt " chunk followed by a single "data" chunk that holds *raw*
    unchanged. Returns the complete file as bytes.
    """
    payload = bytes(raw)
    channels, sample_bits = 1, 8
    frame_bytes = channels * sample_bits // 8

    fmt_chunk = b"fmt " + struct.pack(
        "<IHHIIHH",
        16,                          # size of the fmt chunk body
        1,                           # audio format: PCM
        channels,
        sample_rate,
        sample_rate * frame_bytes,   # byte rate
        frame_bytes,                 # block align
        sample_bits,
    )
    data_chunk = b"data" + struct.pack("<I", len(payload)) + payload

    # The RIFF size field counts everything after itself: "WAVE" + chunks.
    body = b"WAVE" + fmt_chunk + data_chunk
    return b"RIFF" + struct.pack("<I", len(body)) + body

def load_sounds(wad, sound_table):
    """Build one WAV blob per ``sound_table`` entry, in table order.

    Args:
        wad: object providing ``extract_sound(lump_name) -> (rate, pcm)``
            that raises KeyError for unknown lumps (the WADFile in this file).
        sound_table: sequence of tuples whose first element is the sound
            name without its "DS" lump prefix.

    Returns:
        A list of WAV byte strings, index-aligned with ``sound_table``. The
        "none" entry, sounds missing from the WAD and empty lumps all map to
        the same zero-length WAV, so every element has the same type and can
        be handed to the mixer the same way.
    """
    # Placeholder with no samples. A real sample rate is used (instead of 0)
    # so the header's sample/byte rate fields are not zero.
    silence = raw_u8_mono_to_wav(b"", 11025)
    sounds = []

    for entry in sound_table:
        name = entry[0]

        # Entry 0 is a "no sound" slot; it still needs a WAV so indices
        # stay aligned and the element type stays uniform.
        if name == "none":
            sounds.append(silence)
            continue

        lump = "DS" + name.upper()
        print(lump)

        try:
            rate, pcm = wad.extract_sound(lump)
        except KeyError:
            # Sound not found in WAD -> silence
            sounds.append(silence)
            continue

        sounds.append(raw_u8_mono_to_wav(pcm, rate) if pcm else silence)

    return sounds

def play_sound_by_index(sounds, index):
    """Start playback of ``sounds[index]`` through sndmixer.

    Indices outside ``0 .. len(sounds) - 1`` are ignored. In particular, a
    negative index must not wrap around to the end of the list, which plain
    ``index < len(sounds)`` would allow.
    """
    if not 0 <= index < len(sounds):
        return
    audio_stream_id = sndmixer.wav(sounds[index])
    sndmixer.play(audio_stream_id)

# ---------- Path helpers ----------

def op_split(path):
    """Split *path* at its last "/" into ``(head, tail)``, like os.path.split.

    A path with no "/" yields ``("", path)``; a path whose only "/" is the
    leading one yields ``("/", tail)``; an empty (or None) path yields
    ``("", "")``.
    """
    if not path:
        return ("", "")
    cut = path.rfind("/")
    if cut < 0:
        return ("", path)
    # An empty head means the separator was the root slash.
    return (path[:cut] or "/", path[cut + 1:])

def getcwd():
    """Return the directory containing this script.

    Derived from ``__file__``; despite the name, this is not the process
    working directory.
    """
    script_dir, _ = op_split(__file__)
    return script_dir


# ---------- SPI protocol constants ----------

# Opcodes
# Command byte of a button report: opcode, state hi/lo, changed-mask hi/lo.
OP_BTN_REPORT      = 0xF4
# Asks the FPGA for the header of its pending fread request.
OP_FREAD_GET       = 0xF8
# Prefixes the data bytes sent back in answer to an fread request.
OP_FREAD_PUT       = 0xF9
# Asks the FPGA for its pending play_sound request.
OP_PLAY_SOUND_GET  = 0xFA
# Leading byte of the transaction that reads a requested payload back.
OP_RESP_ACK        = 0xFE
# Two-byte no-op used for polling; the second response byte holds IRQ bits.
OP_NOP2            = 0xFF

# IRQ bits (from FPGA to ESP)
# FPGA has an fread request waiting to be served.
IRQ_FREAD          = 0x01
# FPGA has a play_sound request waiting to be served.
IRQ_PLAY_SOUND     = 0x02

# Max fread payload (0x400 bytes + 1 opcode)
# The fread reply buffer is pre-sized to this many data bytes plus the opcode.
# NOTE(review): the request length field is 16 bits, so this presumably relies
# on the FPGA never asking for more than 0x400 bytes at once — confirm.
FREAD_MAX_LEN      = 0x400


# ---------- Button handling ----------

# Bitmask of all buttons currently held down.
g_btn_state = 0

def send_button_report(btn_mask, pressed):
    """Record a press/release of *btn_mask* and report the new state to the FPGA.

    The 5-byte report is: opcode, full button state (high byte, low byte),
    and the mask of the button that changed (high byte, low byte).
    """
    global g_btn_state

    g_btn_state = (g_btn_state | btn_mask) if pressed else (g_btn_state & ~btn_mask)

    report = bytes((
        OP_BTN_REPORT,
        (g_btn_state >> 8) & 0xFF,
        g_btn_state & 0xFF,
        (btn_mask >> 8) & 0xFF,
        btn_mask & 0xFF,
    ))
    mch22.fpga_send(report)


def make_button_handler(mask):
    """Return a button callback that reports *mask* as pressed or released."""
    return lambda pressed: send_button_report(mask, pressed)


def setup_buttons():
    """Forward every badge button's press/release to the FPGA.

    Each button is identified by a single bit in the reported button mask.
    """
    button_bits = (
        (buttons.BTN_A, 9),
        (buttons.BTN_B, 10),
        (buttons.BTN_HOME, 5),
        (buttons.BTN_MENU, 6),
        (buttons.BTN_SELECT, 7),
        (buttons.BTN_START, 8),
        (buttons.BTN_LEFT, 2),
        (buttons.BTN_RIGHT, 3),
        (buttons.BTN_UP, 1),
        (buttons.BTN_DOWN, 0),
        (buttons.BTN_PRESS, 4),
    )
    for button, bit in button_bits:
        buttons.attach(button, make_button_handler(1 << bit))


# ---------- FPGA server core ----------

class FpgaFileServer:
    """Answer requests the FPGA raises over SPI: file reads and sound playback.

    ``loop()`` polls the FPGA for IRQ bits and dispatches to the matching
    handler. File reads are served from ``fpga_<id as 8 hex digits>.dat``
    files inside ``base_dir``.
    """

    def __init__(self, base_dir):
        self.base_dir = base_dir
        # file_id -> open file object, or None when the file could not be
        # opened; cached so each id is only looked up once.
        self.file_ids = {}
        self._closed = False

        # Prebuilt SPI command buffers: opcode in byte 0, the rest zero.

        # Polls the IRQ status; the second response byte holds the IRQ bits.
        self.spi_cmd_nop2 = bytearray(2)
        self.spi_cmd_nop2[0] = OP_NOP2

        self.spi_cmd_fread_get = bytearray(1)
        self.spi_cmd_fread_get[0] = OP_FREAD_GET

        # Reads back the 12-byte fread request header.
        self.spi_cmd_resp_ack = bytearray(12)
        self.spi_cmd_resp_ack[0] = OP_RESP_ACK

        # Reads back the 3-byte play_sound request (sound id in byte 2).
        self.spi_cmd_resp_sndack = bytearray(3)
        self.spi_cmd_resp_sndack[0] = OP_RESP_ACK

        # fread reply: opcode followed by up to FREAD_MAX_LEN data bytes.
        self.spi_cmd_fread_put = bytearray(FREAD_MAX_LEN + 1)
        self.spi_cmd_fread_put[0] = OP_FREAD_PUT

        self.spi_cmd_play_sound_get = bytearray(1)
        self.spi_cmd_play_sound_get[0] = OP_PLAY_SOUND_GET

    def close(self):
        """Close every cached file; calling it again is a no-op."""
        if self._closed:
            return
        for f in self.file_ids.values():
            if f is not None:
                try:
                    f.close()
                except Exception:
                    # Best effort during shutdown: one failing close must not
                    # prevent the remaining files from being closed.
                    pass
        self._closed = True

    def _get_file(self, file_id):
        """Return the open file for *file_id*, or None if it cannot be opened.

        The result, including None, is cached in ``self.file_ids``.
        """
        if file_id in self.file_ids:
            return self.file_ids[file_id]

        filename = f'fpga_{file_id:08x}.dat'
        full_path = self.base_dir + "/" + filename

        # Just try to open it: this avoids listing the whole directory for
        # every new id, and does not depend on how the filesystem reports
        # the case of file names.
        try:
            f = open(full_path, "rb")
        except OSError:
            f = None

        self.file_ids[file_id] = f
        return f

    def _handle_fread_irq(self):
        """Fetch one fread request from the FPGA and send back the bytes.

        Request header (big-endian): bytes 2-5 file id, 6-9 file offset,
        10-11 length minus one. A missing file or a read past end of file
        is answered with zero bytes so the reply always has the requested
        length.
        """
        # Ask FPGA for fread request header
        mch22.fpga_send(bytes(self.spi_cmd_fread_get))
        buf = mch22.fpga_transaction(bytes(self.spi_cmd_resp_ack))

        # buf[0] echoes the command slot; buf[1] is not used here.
        req_file_id = (buf[2] << 24) | (buf[3] << 16) | (buf[4] << 8) | buf[5]
        req_offset  = (buf[6] << 24) | (buf[7] << 16) | (buf[8] << 8) | buf[9]
        req_length  = ((buf[10] << 8) | buf[11]) + 1

        f = self._get_file(req_file_id)
        if f is None:
            data = bytearray(req_length)
        else:
            f.seek(req_offset, 0)
            data = f.read(req_length) or b""

        data_len = len(data)
        if data_len < req_length:
            # pad with zeros if short read
            data = data + bytes(req_length - data_len)
            data_len = req_length

        # NOTE(review): a request longer than FREAD_MAX_LEN grows the reply
        # buffer via slice assignment; presumably the FPGA never asks for
        # that much.
        self.spi_cmd_fread_put[1:data_len + 1] = data
        mch22.fpga_send(bytes(self.spi_cmd_fread_put[:data_len + 1]))

    def _handle_play_sound_irq(self):
        """Fetch a play_sound request from the FPGA and start that sound.

        Reads the module-level ``sounds`` list built at startup.
        """
        # Ask FPGA for the sound_id
        mch22.fpga_send(bytes(self.spi_cmd_play_sound_get))
        resp = mch22.fpga_transaction(bytes(self.spi_cmd_resp_sndack))
        print(''.join('{:02x}'.format(x) for x in resp))
        sound_id = resp[2]

        print(f"[FPGA] play_sound request: sound_id={sound_id:02x}")
        play_sound_by_index(sounds, sound_id)

    def loop(self):
        """Serve FPGA requests forever; cached files are closed on exit."""
        try:
            while True:
                # Busy-poll until the FPGA reports at least one IRQ bit.
                while True:
                    _, irqs = mch22.fpga_transaction(bytes(self.spi_cmd_nop2))
                    if irqs:
                        break

                if irqs & IRQ_FREAD:
                    self._handle_fread_irq()

                if irqs & IRQ_PLAY_SOUND:
                    self._handle_play_sound_irq()
        finally:
            self.close()


# ---------- Top-level setup ----------

def setup(base):
    """Load ``<base>/bitstream.bin`` into the FPGA, hook up the buttons and
    serve FPGA requests until the server loop exits."""
    bitstream_path = base + "/bitstream.bin"
    print("Loading bitstream:", bitstream_path)

    with open(bitstream_path, "rb") as bitstream_file:
        mch22.lcd_mode(1)
        image = bitstream_file.read()
        mch22.fpga_load(image)

    setup_buttons()

    FpgaFileServer(base).loop()


def _show_status(y, text):
    # Draw one line of white progress text at row *y* and push it to the LCD.
    display.drawText(5, y, text, 0xFFFFFF, "dejavusans20")
    display.flush()


# Startup: parse the WAD, convert its sounds, then hand over to the FPGA.
base = getcwd()
display.drawFill(0x000000)
_show_status(70, "Parsing WAD file...")
wad = WADFile(base + "/fpga_01c4546d.dat")
_show_status(100, "Loading DOOM sounds...")
sounds = load_sounds(wad, sound_table)
_show_status(130, "Starting...")
sndmixer.begin(2, True)
setup(base)
