import mch22
import buttons
import display
import os
import sys
import struct
import sndmixer
from .bluetooth_keyboard import BleHIDKeyboard
import time

class WADFile:
    """Minimal read-only parser for Doom IWAD/PWAD archives.

    The header and lump directory are read at construction time; lump
    contents are read on demand by extract_sound().
    """

    def __init__(self, path):
        self.path = path
        self.ident = None
        self.numlumps = 0
        self.infotableofs = 0
        # Upper-cased lump name -> (filepos, size). When a name occurs more
        # than once, the entry appearing last in the directory wins.
        self.directory = dict()
        self._parse()

    def _parse(self):
        """Read the 12-byte header and the 16-byte-per-lump directory.

        Raises:
            ValueError: if the file does not start with "IWAD" or "PWAD".
        """
        with open(self.path, "rb") as f:
            # --- Header: 4-byte ident, uint32 numlumps, uint32 directory offset ---
            self.ident = f.read(4).decode("ascii", "ignore")
            if self.ident not in ("IWAD", "PWAD"):
                raise ValueError("Not a WAD file")

            self.numlumps = int.from_bytes(f.read(4), "little")
            self.infotableofs = int.from_bytes(f.read(4), "little")

            # --- Directory: numlumps * (uint32 filepos, uint32 size, char[8] name) ---
            f.seek(self.infotableofs)
            for _ in range(self.numlumps):
                entry = f.read(16)
                if len(entry) < 16:
                    break  # truncated directory: keep the entries read so far
                filepos = int.from_bytes(entry[0:4], "little")
                size = int.from_bytes(entry[4:8], "little")
                # Names are NUL-padded; anything after the first NUL is junk.
                raw_name = entry[8:16].split(b"\0", 1)[0].rstrip(b" ")
                name = raw_name.decode("ascii", "ignore").upper()
                self.directory[name] = (filepos, size)

    def extract_sound(self, name):
        """Return ``(samplerate, pcm)`` for the sound lump *name*.

        *name* is matched case-insensitively. DMX digitized sounds have this
        header: uint16 format (3), uint16 sample rate, and uint32 sample count,
        where the count includes 16 pad bytes before and 16 after the samples.
        For these lumps the header and padding are stripped, so *pcm* holds
        only the unsigned 8-bit samples. Lumps shorter than 4 bytes yield
        ``(0, b"")``.

        Raises:
            KeyError: if no lump called *name* exists.
        """
        entry = self.directory.get(name.upper())
        if entry is None:
            raise KeyError("Sound lump not found")
        filepos, size = entry

        with open(self.path, "rb") as f:
            f.seek(filepos)
            data = f.read(size)

        if len(data) >= 8 and data[0] == 3 and data[1] == 0:
            # DMX digitized sound.
            samplerate = int.from_bytes(data[2:4], "little")
            # Never trust the stored count beyond the end of the lump.
            samplecount = min(int.from_bytes(data[4:8], "little"), len(data) - 8)
            pcm = data[8:8 + samplecount]
            if len(pcm) > 32:
                pcm = pcm[16:-16]  # drop the 16-byte pads on either side
            return samplerate, pcm

        if len(data) < 4:
            # Too short to carry any header: treat as silence.
            return 0, b""

        # Non-DMX lump: keep the previous heuristic (format byte 0-3 selects a
        # fixed rate, otherwise the first uint16 is the rate; uint16 length at 2).
        fmt = data[0]
        if fmt in (0, 1, 2, 3):
            samplerate = 11025 if fmt == 3 else 8000
        else:
            samplerate = int.from_bytes(data[0:2], "little")
        samplecount = int.from_bytes(data[2:4], "little")
        return samplerate, data[4:4 + samplecount]

# Sound definitions, indexed by the sound id the FPGA sends: load_sounds()
# builds one entry per row in this order, and play_sound_by_index() indexes
# that list directly. Do not reorder or remove rows.
# Only field 0 is used in this file: the lump name without its "DS" prefix.
# Row 0 ("none") is a placeholder that is never looked up in the WAD.
# NOTE(review): the remaining fields appear to mirror Doom's sfxinfo_t
# initializers (singularity, priority, link, pitch, volume, data); they are
# unused here.
sound_table = [
    ("none",    False,  0,   0, -1, -1, 0),
    ("pistol",  False, 64,   0, -1, -1, 0),
    ("shotgn",  False, 64,   0, -1, -1, 0),
    ("sgcock",  False, 64,   0, -1, -1, 0),
    ("dshtgn",  False, 64,   0, -1, -1, 0),
    ("dbopn",   False, 64,   0, -1, -1, 0),
    ("dbcls",   False, 64,   0, -1, -1, 0),
    ("dbload",  False, 64,   0, -1, -1, 0),
    ("plasma",  False, 64,   0, -1, -1, 0),
    ("bfg",     False, 64,   0, -1, -1, 0),
    ("sawup",   False, 64,   0, -1, -1, 0),
    ("sawidl",  False,118,   0, -1, -1, 0),
    ("sawful",  False, 64,   0, -1, -1, 0),
    ("sawhit",  False, 64,   0, -1, -1, 0),
    ("rlaunc",  False, 64,   0, -1, -1, 0),
    ("rxplod",  False, 70,   0, -1, -1, 0),
    ("firsht",  False, 70,   0, -1, -1, 0),
    ("firxpl",  False, 70,   0, -1, -1, 0),
    ("pstart",  False,100,   0, -1, -1, 0),
    ("pstop",   False,100,   0, -1, -1, 0),
    ("doropn",  False,100,   0, -1, -1, 0),
    ("dorcls",  False,100,   0, -1, -1, 0),
    ("stnmov",  False,119,   0, -1, -1, 0),
    ("swtchn",  False, 78,   0, -1, -1, 0),
    ("swtchx",  False, 78,   0, -1, -1, 0),
    ("plpain",  False, 96,   0, -1, -1, 0),
    ("dmpain",  False, 96,   0, -1, -1, 0),
    ("popain",  False, 96,   0, -1, -1, 0),
    ("vipain",  False, 96,   0, -1, -1, 0),
    ("mnpain",  False, 96,   0, -1, -1, 0),
    ("pepain",  False, 96,   0, -1, -1, 0),
    ("slop",    False, 78,   0, -1, -1, 0),
    ("itemup",  True,  78,   0, -1, -1, 0),
    ("wpnup",   True,  78,   0, -1, -1, 0),
    ("oof",     False, 96,   0, -1, -1, 0),
    ("telept",  False, 32,   0, -1, -1, 0),
    ("posit1",  True,  98,   0, -1, -1, 0),
    ("posit2",  True,  98,   0, -1, -1, 0),
    ("posit3",  True,  98,   0, -1, -1, 0),
    ("bgsit1",  True,  98,   0, -1, -1, 0),
    ("bgsit2",  True,  98,   0, -1, -1, 0),
    ("sgtsit",  True,  98,   0, -1, -1, 0),
    ("cacsit",  True,  98,   0, -1, -1, 0),
    ("brssit",  True,  94,   0, -1, -1, 0),
    ("cybsit",  True,  92,   0, -1, -1, 0),
    ("spisit",  True,  90,   0, -1, -1, 0),
    ("bspsit",  True,  90,   0, -1, -1, 0),
    ("kntsit",  True,  90,   0, -1, -1, 0),
    ("vilsit",  True,  90,   0, -1, -1, 0),
    ("mansit",  True,  90,   0, -1, -1, 0),
    ("pesit",   True,  90,   0, -1, -1, 0),
    ("sklatk",  False, 70,   0, -1, -1, 0),
    ("sgtatk",  False, 70,   0, -1, -1, 0),
    ("skepch",  False, 70,   0, -1, -1, 0),
    ("vilatk",  False, 70,   0, -1, -1, 0),
    ("claw",    False, 70,   0, -1, -1, 0),
    ("skeswg",  False, 70,   0, -1, -1, 0),
    ("pldeth",  False, 32,   0, -1, -1, 0),
    ("pdiehi",  False, 32,   0, -1, -1, 0),
    ("podth1",  False, 70,   0, -1, -1, 0),
    ("podth2",  False, 70,   0, -1, -1, 0),
    ("podth3",  False, 70,   0, -1, -1, 0),
    ("bgdth1",  False, 70,   0, -1, -1, 0),
    ("bgdth2",  False, 70,   0, -1, -1, 0),
    ("sgtdth",  False, 70,   0, -1, -1, 0),
    ("cacdth",  False, 70,   0, -1, -1, 0),
    ("skldth",  False, 70,   0, -1, -1, 0),
    ("brsdth",  False, 32,   0, -1, -1, 0),
    ("cybdth",  False, 32,   0, -1, -1, 0),
    ("spidth",  False, 32,   0, -1, -1, 0),
    ("bspdth",  False, 32,   0, -1, -1, 0),
    ("vildth",  False, 32,   0, -1, -1, 0),
    ("kntdth",  False, 32,   0, -1, -1, 0),
    ("pedth",   False, 32,   0, -1, -1, 0),
    ("skedth",  False, 32,   0, -1, -1, 0),
    ("posact",  True, 120,   0, -1, -1, 0),
    ("bgact",   True, 120,   0, -1, -1, 0),
    ("dmact",   True, 120,   0, -1, -1, 0),
    ("bspact",  True, 100,   0, -1, -1, 0),
    ("bspwlk",  True, 100,   0, -1, -1, 0),
    ("vilact",  True, 100,   0, -1, -1, 0),
    ("noway",   False, 78,   0, -1, -1, 0),
    ("barexp",  False, 60,   0, -1, -1, 0),
    ("punch",   False, 64,   0, -1, -1, 0),
    ("hoof",    False, 70,   0, -1, -1, 0),
    ("metal",   False, 70,   0, -1, -1, 0),
    ("chgun",   False, 64,   0, -1, -1, 0),
    ("tink",    False, 60,   0, -1, -1, 0),
    ("bdopn",   False,100,   0, -1, -1, 0),
    ("bdcls",   False,100,   0, -1, -1, 0),
    ("itmbk",   False,100,   0, -1, -1, 0),
    ("flame",   False, 32,   0, -1, -1, 0),
    ("flamst",  False, 32,   0, -1, -1, 0),
    ("getpow",  False, 60,   0, -1, -1, 0),
    ("bospit",  False, 70,   0, -1, -1, 0),
    ("boscub",  False, 70,   0, -1, -1, 0),
    ("bossit",  False, 70,   0, -1, -1, 0),
    ("bospn",   False, 70,   0, -1, -1, 0),
    ("bosdth",  False, 70,   0, -1, -1, 0),
    ("manatk",  False, 70,   0, -1, -1, 0),
    ("mandth",  False, 70,   0, -1, -1, 0),
    ("sssit",   False, 70,   0, -1, -1, 0),
    ("ssdth",   False, 70,   0, -1, -1, 0),
    ("keenpn",  False, 70,   0, -1, -1, 0),
    ("keendt",  False, 70,   0, -1, -1, 0),
    ("skeact",  False, 70,   0, -1, -1, 0),
    ("skesit",  False, 70,   0, -1, -1, 0),
    ("skeatk",  False, 70,   0, -1, -1, 0),
    ("radio",   False, 60,   0, -1, -1, 0),
]

def raw_u8_mono_to_wav(raw: bytearray, sample_rate: int) -> bytes:
    """
    Wrap unsigned 8-bit mono PCM samples in a minimal RIFF/WAVE container.

    The result is a canonical 44-byte header followed by the samples,
    unchanged. The header consists of the RIFF chunk header, a 16-byte PCM
    "fmt " chunk and the "data" chunk header. It is packed with struct
    directly; the wave module is not used.
    """
    sample_count = len(raw)
    frame_bytes = 1 * 8 // 8  # channels * bits_per_sample // 8

    header = struct.pack(
        "<4sI4s" "4sIHHIIHH" "4sI",
        # RIFF chunk: size covers "WAVE" + fmt chunk (8 + 16) + data chunk (8 + n)
        b"RIFF", 36 + sample_count, b"WAVE",
        # fmt chunk: PCM (1), mono (1), rate, byte rate, block align, 8 bits
        b"fmt ", 16, 1, 1, sample_rate, sample_rate * frame_bytes, frame_bytes, 8,
        # data chunk header
        b"data", sample_count,
    )
    return header + bytes(raw)

def load_sounds(wad, sound_table):
    """Build the sound bank that play_sound_by_index() indexes into.

    For every row of *sound_table*, in order, the lump ``"DS" + NAME`` is
    extracted from *wad* and wrapped as an in-memory WAV. List position
    therefore equals the row's index in *sound_table*.

    The placeholder row "none" and lumps missing from the WAD both become an
    empty WAV (sample rate 0, no samples). Every element of the returned list
    is thus ``bytes``; previously the "none" slot held a ``(0, b"")`` tuple.
    """
    sounds = []

    for entry in sound_table:
        name = entry[0]
        rate, pcm = 0, b""  # silence unless a lump is found

        if name != "none":
            try:
                rate, pcm = wad.extract_sound("DS" + name.upper())
            except KeyError:
                pass  # sound not present in this WAD: keep silence

        sounds.append(raw_u8_mono_to_wav(pcm, rate))

    return sounds

def play_sound_by_index(sounds, index):
    """Start playback of ``sounds[index]`` through sndmixer at volume 50.

    The call is silently skipped in three cases:
      * *index* is out of range, including negative values;
      * the entry is not a bytes/bytearray WAV;
      * the entry is no longer than a bare 44-byte WAV header, i.e. it holds
        no samples.
    """
    if not 0 <= index < len(sounds):
        return

    wav = sounds[index]
    # Placeholders/silent entries have nothing to play; don't hand them to sndmixer.
    if not isinstance(wav, (bytes, bytearray)) or len(wav) <= 44:
        return

    stream_id = sndmixer.wav(wav)
    sndmixer.play(stream_id)
    sndmixer.volume(stream_id, 50)
        
# ---------- Path helpers ----------

def op_split(path):
    """Split *path* at its last "/" into ``(head, tail)``.

    An empty/falsy path gives ``("", "")`` and a path without "/" gives
    ``("", path)``. An empty head (path directly under the root) is reported
    as "/".
    """
    if not path:
        return ("", "")
    cut = path.rfind("/")
    if cut < 0:
        return ("", path)
    return (path[:cut] or "/", path[cut + 1:])

def getcwd():
    """Return the directory that contains this module, derived from ``__file__``."""
    directory, _name = op_split(__file__)
    return directory


# ---------- SPI protocol constants ----------

# Opcodes (first byte of each SPI command sent to the FPGA)
OP_BTN_REPORT      = 0xF4   # button state report (send_button_report)
OP_FREAD_GET       = 0xF8   # request the pending fread header
OP_FREAD_PUT       = 0xF9   # reply carrying fread data
OP_PLAY_SOUND_GET  = 0xFA   # request the pending play-sound id
OP_KEYBOARD_GET    = 0xFB   # keyboard / mouse event (send_key_report)
OP_RESP_ACK        = 0xFE   # sent via fpga_transaction to read back the FPGA's reply
OP_NOP2            = 0xFF   # 2-byte poll; the 2nd reply byte carries the IRQ bits

# IRQ bits (from FPGA to ESP)
IRQ_FREAD          = 0x01
# FPGA requests sound playback (handled in FpgaFileServer.loop)
IRQ_PLAY_SOUND     = 0x02

# Max fread payload in bytes; the OP_FREAD_PUT buffer adds 1 opcode byte
FREAD_MAX_LEN      = 0x400


# ---------- Button handling ----------

g_btn_state = 0  # bitmap of currently pressed badge buttons


def send_button_report(btn_mask, pressed):
    """Update the global button bitmap and send it to the FPGA.

    The payload is 5 bytes:
      * OP_BTN_REPORT;
      * the whole button state as a big-endian 16-bit value;
      * the mask of the button that changed, also big-endian 16-bit.
    """
    global g_btn_state

    g_btn_state = (g_btn_state | btn_mask) if pressed else (g_btn_state & ~btn_mask)

    report = bytes((
        OP_BTN_REPORT,
        (g_btn_state >> 8) & 0xFF,
        g_btn_state & 0xFF,
        (btn_mask >> 8) & 0xFF,
        btn_mask & 0xFF,
    ))
    mch22.fpga_send(report)

def send_key_report(pressed, key):
    """Send one OP_KEYBOARD_GET message: the opcode byte, then an 8-bit key byte.

    When *pressed* is true, bit 7 of the key byte is set. The mouse handlers
    pass signed motion deltas through here, so the byte is masked to 0..255
    explicitly and negatives go out as their two's-complement low byte. This
    removes the dependence on how the interpreter handles out-of-range
    bytearray stores; CPython, for example, raises ValueError.
    """
    if pressed:
        key = key | 0x80
    payload = bytearray(2)
    payload[0] = OP_KEYBOARD_GET
    payload[1] = key & 0xFF
    mch22.fpga_send(bytes(payload))


def make_button_handler(mask):
    """Return a ``handler(pressed)`` callback that reports *mask* via send_button_report."""
    return lambda pressed: send_button_report(mask, pressed)


def setup_buttons():
    """Register a press/release handler for every badge button.

    Each button reports under its own bit of the 16-bit state sent by
    send_button_report(). The bit numbers keep the original mapping.
    """
    bit_for_button = (
        (buttons.BTN_A,      9),
        (buttons.BTN_B,      10),
        (buttons.BTN_HOME,   5),
        (buttons.BTN_MENU,   6),
        (buttons.BTN_SELECT, 7),
        (buttons.BTN_START,  8),
        (buttons.BTN_LEFT,   2),
        (buttons.BTN_RIGHT,  3),
        (buttons.BTN_UP,     1),
        (buttons.BTN_DOWN,   0),
        (buttons.BTN_PRESS,  4),
    )
    for button, bit in bit_for_button:
        buttons.attach(button, make_button_handler(1 << bit))


# ---------- FPGA server core ----------

class FpgaFileServer:
    """Answer FPGA requests arriving over the mch22 SPI link.

    loop() polls the FPGA for IRQ bits and dispatches them:
      * IRQ_FREAD      -> _handle_fread_irq(): send a block of fpga_<id>.dat
      * IRQ_PLAY_SOUND -> _handle_play_sound_irq(): play an entry of the
        module-level ``sounds`` list
    """

    def __init__(self, base_dir):
        self.base_dir = base_dir
        self.file_ids = {}   # file_id -> open file object, or None if unavailable
        self._closed = False

        # Prebuild common SPI command buffers (byte 0 is the opcode).
        self.spi_cmd_nop2      = bytearray(2)
        self.spi_cmd_nop2[0]   = OP_NOP2

        self.spi_cmd_fread_get = bytearray(1)
        self.spi_cmd_fread_get[0] = OP_FREAD_GET

        # 12 bytes: long enough to clock out the whole fread request header.
        self.spi_cmd_resp_ack  = bytearray(12)
        self.spi_cmd_resp_ack[0] = OP_RESP_ACK

        # 3 bytes: the sound id comes back in byte 2.
        self.spi_cmd_resp_sndack  = bytearray(3)
        self.spi_cmd_resp_sndack[0] = OP_RESP_ACK

        self.spi_cmd_fread_put = bytearray(FREAD_MAX_LEN + 1)
        self.spi_cmd_fread_put[0] = OP_FREAD_PUT

        self.spi_cmd_play_sound_get = bytearray(1)
        self.spi_cmd_play_sound_get[0] = OP_PLAY_SOUND_GET

    def close(self):
        """Close every cached file handle. Calling it again is a no-op."""
        if self._closed:
            return
        for f in self.file_ids.values():
            if f is not None:
                try:
                    f.close()
                except Exception:
                    pass  # best effort during shutdown
        # Forget the closed handles so _get_file() can never return a dead file.
        self.file_ids.clear()
        self._closed = True

    def _get_file(self, file_id):
        """Return the open file for *file_id*, or None if it is unavailable.

        The file is ``<base_dir>/fpga_<file_id as 8 hex digits>.dat``. Hits
        and misses are both cached, so each id is looked up at most once.
        """
        if file_id in self.file_ids:
            return self.file_ids[file_id]

        filename = f'fpga_{file_id:08x}.dat'
        full_path = self.base_dir + "/" + filename

        if filename not in os.listdir(self.base_dir):
            self.file_ids[file_id] = None
            return None

        try:
            f = open(full_path, "rb")
        except OSError:
            f = None

        self.file_ids[file_id] = f
        return f

    def _handle_fread_irq(self):
        """Serve one fread request: read the request header, reply with file bytes.

        The header fields are read as big-endian values:
          * buf[2:6]   file id
          * buf[6:10]  byte offset
          * buf[10:12] length minus 1
        If the file is missing, the reply is all zeros. A short read past EOF
        is zero-padded to the requested length.
        """
        mch22.fpga_send(bytes(self.spi_cmd_fread_get))
        buf = mch22.fpga_transaction(bytes(self.spi_cmd_resp_ack))

        # buf[0] == OP_RESP_ACK; buf[1] is not used here.
        req_file_id = (buf[2] << 24) | (buf[3] << 16) | (buf[4] << 8) | buf[5]
        req_offset  = (buf[6] << 24) | (buf[7] << 16) | (buf[8] << 8) | buf[9]
        # NOTE(review): this field can encode up to 0x10000 bytes, while the
        # reply buffer holds FREAD_MAX_LEN; presumably the FPGA never asks for
        # more. Confirm.
        req_length  = ((buf[10] << 8) | buf[11]) + 1

        f = self._get_file(req_file_id)
        if f is None:
            data = bytearray(req_length)
        else:
            f.seek(req_offset, 0)
            data = f.read(req_length) or b""

        data_len = len(data)
        if data_len < req_length:
            # pad with zeros if short read
            data = data + bytes(req_length - data_len)
            data_len = req_length

        # Build and send the response: opcode followed by the payload.
        self.spi_cmd_fread_put[1:data_len + 1] = data
        mch22.fpga_send(bytes(self.spi_cmd_fread_put[:data_len + 1]))

    def _handle_play_sound_irq(self):
        """Fetch the requested sound id from the FPGA and play that entry of ``sounds``."""
        # Ask FPGA for the sound_id
        mch22.fpga_send(bytes(self.spi_cmd_play_sound_get))
        resp = mch22.fpga_transaction(bytes(self.spi_cmd_resp_sndack))
        print(''.join('{:02x}'.format(x) for x in resp))
        sound_id = resp[2]

        print(f"[FPGA] play_sound request: sound_id={sound_id:02x}")
        # `sounds` is the module-level bank built by load_sounds() (read-only here).
        play_sound_by_index(sounds, sound_id)

    def loop(self):
        """Poll for FPGA IRQs and dispatch them; returns only via an exception.

        Cached file handles are closed on the way out.
        """
        try:
            # Convert once instead of allocating a new bytes object every 1 ms poll.
            nop2 = bytes(self.spi_cmd_nop2)
            while True:
                # Wait for any IRQ (2nd reply byte of the NOP poll).
                while True:
                    _, irqs = mch22.fpga_transaction(nop2)
                    if irqs:
                        break
                    time.sleep_ms(1)

                if irqs & IRQ_FREAD:
                    self._handle_fread_irq()

                if irqs & IRQ_PLAY_SOUND:
                    self._handle_play_sound_irq()
        finally:
            self.close()

# USB HID keyboard usage ID -> key code sent to the FPGA by my_key_handler()
# through send_key_report().
# Letters and digits map to their lowercase ASCII codes. Arrows, Escape,
# Enter, Tab, Backspace, =, - and F1-F12 map to small codes 0x00-0x19.
# NOTE(review): those small codes are presumably decoded by the FPGA-side
# Doom port; confirm there.
# Usage IDs that are not keys of this dict have no mapping.
hid_to_doom = {
    0x04: 0x61,  # a
    0x05: 0x62,  # b
    0x06: 0x63,  # c
    0x07: 0x64,  # d
    0x08: 0x65,  # e
    0x09: 0x66,  # f
    0x0A: 0x67,  # g
    0x0B: 0x68,  # h
    0x0C: 0x69,  # i
    0x0D: 0x6A,  # j
    0x0E: 0x6B,  # k
    0x0F: 0x6C,  # l
    0x10: 0x6D,  # m
    0x11: 0x6E,  # n
    0x12: 0x6F,  # o
    0x13: 0x70,  # p
    0x14: 0x71,  # q
    0x15: 0x72,  # r
    0x16: 0x73,  # s
    0x17: 0x74,  # t
    0x18: 0x75,  # u
    0x19: 0x76,  # v
    0x1A: 0x77,  # w
    0x1B: 0x78,  # x
    0x1C: 0x79,  # y
    0x1D: 0x7A,  # z

    0x1E: 0x31,  # 1
    0x1F: 0x32,  # 2
    0x20: 0x33,  # 3
    0x21: 0x34,  # 4
    0x22: 0x35,  # 5
    0x23: 0x36,  # 6
    0x24: 0x37,  # 7
    0x25: 0x38,  # 8
    0x26: 0x39,  # 9
    0x27: 0x30,  # 0

    0x50: 0x00,  # Left
    0x4F: 0x01,  # Right
    0x51: 0x02,  # Down
    0x52: 0x03,  # Up
    0x29: 0x07,  # Escape
    0x28: 0x08,  # Enter
    0x2B: 0x09,  # Tab
    0x2A: 0x0A,  # Backspace
#   0x31: 0x0B,  # Pause  -- NOTE(review): 0x31 is labelled "\" below; HID Pause is 0x48
    0x2E: 0x0C,  # Equals
    0x2D: 0x0D,  # Minus
    0x3A: 0x0E,  # F1
    0x3B: 0x0F,  # F2
    0x3C: 0x10,  # F3
    0x3D: 0x11,  # F4
    0x3E: 0x12,  # F5
    0x3F: 0x13,  # F6
    0x40: 0x14,  # F7
    0x41: 0x15,  # F8
    0x42: 0x16,  # F9
    0x43: 0x17,  # F10
    0x44: 0x18,  # F11
    0x45: 0x19,  # F12

    0x2C: 0x20,  # Space

    0x2F: 0x5B,  # [
    0x30: 0x5D,  # ]
#    0x31: 0x5C,  # \
    0x33: 0x3B,  # ;
    0x34: 0x27,  # '
    0x35: 0x60,  # `
    0x36: 0x2C,  # ,
    0x37: 0x2E,  # .
    0x38: 0x2F,  # /
}

# Keyboard modifier bit -> key code sent to the FPGA by my_mask_handler()
# through send_key_report(). The left and right variants of Shift, Ctrl and
# Alt share one code each. Bits not listed here (e.g. 0x08, 0x80) have no
# entry.
mask_to_doom = {
    0x02: 0x04, # L-Shift
    0x20: 0x04, # R-Shift
    0x01: 0x05, # L-Ctrl
    0x10: 0x05, # R-Ctrl
    0x04: 0x06, # L-Alt
    0x40: 0x06, # R-Alt
}

def my_key_handler(key, pressed):
    """BLE keyboard callback: forward a key press/release to the FPGA.

    *key* is translated through hid_to_doom. Keys with no mapping are still
    printed but otherwise ignored. Previously they raised KeyError inside the
    BLE callback.
    """
    print("KEY:", key, "pressed" if pressed else "released")
    doom_key = hid_to_doom.get(key)
    if doom_key is None:
        return
    send_key_report(pressed, doom_key)

def my_mask_handler(modkey, pressed):
    """BLE keyboard callback: forward a modifier press/release to the FPGA.

    *modkey* is translated through mask_to_doom. Modifier bits with no
    mapping are printed and otherwise ignored, instead of raising KeyError
    inside the BLE callback.
    """
    print("MODKEY:", modkey, "pressed" if pressed else "released")
    doom_key = mask_to_doom.get(modkey)
    if doom_key is None:
        return
    send_key_report(pressed, doom_key)

def my_button_handler(button, pressed):
    """BLE mouse-button callback: forward the button as key code ``28 + n``.

    Buttons 1 and 2 are swapped before the offset is added; any other button
    number is passed through unchanged.
    """
    print("Mouse button: ", button, "pressed" if pressed else "released")
    swapped = {1: 2, 2: 1}.get(button, button)
    send_key_report(pressed, 28 + swapped)

# Mouse motion accumulated across reports by my_mouse_handler().
acc_dx = 0
acc_dy = 0
mcnt = 0


def my_mouse_handler(dx, dy, wheel):
    """BLE mouse-motion callback: forward summed motion once every 4 reports.

    When the sum is non-zero, the handler sends three key reports: code 31,
    then the x delta, then the y delta. *wheel* is ignored.

    Each delta travels as a single payload byte. The sum of 4 reports can
    exceed that range and wrap around, reversing the direction, so the deltas
    are clamped to -127..127.
    NOTE(review): this assumes the FPGA reads the deltas as signed 8-bit
    (two's complement) values; confirm.
    """
    global acc_dx, acc_dy, mcnt

    acc_dx += dx
    acc_dy += dy
    mcnt += 1
    if mcnt < 4:
        return

    print("Mouse move4: ", acc_dx, acc_dy)

    if acc_dx != 0 or acc_dy != 0:
        send_key_report(False, 31)
        send_key_report(False, max(-127, min(127, acc_dx)))
        send_key_report(False, max(-127, min(127, acc_dy)))

    acc_dx = 0
    acc_dy = 0
    mcnt = 0

def my_mouse_handler2(dx, dy, wheel):
    """Unbuffered mouse-motion callback: forward every non-zero movement report.

    This handler is not registered with BleHIDKeyboard in this file. It sends
    code 31, then dx, then dy; *wheel* is ignored. Each delta is clamped to
    -127..127 so it fits the single payload byte without wrapping around.
    """
    print("Mouse move2: ", dx, dy)

    if dx != 0 or dy != 0:
        send_key_report(False, 31)
        send_key_report(False, max(-127, min(127, dx)))
        send_key_report(False, max(-127, min(127, dy)))
    

# ---------- Top-level setup ----------

def setup(base):
    """Load ``<base>/bitstream.bin`` into the FPGA, attach buttons, then serve.

    The final FpgaFileServer.loop() call does not return normally.
    """
    bitstream_path = base + "/bitstream.bin"
    print("Loading bitstream:", bitstream_path)

    with open(bitstream_path, "rb") as bitstream_file:
        mch22.lcd_mode(1)
        image = bitstream_file.read()
        mch22.fpga_load(image)

    setup_buttons()

    FpgaFileServer(base).loop()

# ---- Module entry point: everything below runs when the module is loaded ----
base = getcwd()
# BLE HID input is forwarded to the FPGA through the handlers defined above:
# keys, modifiers, mouse buttons, and mouse motion buffered over 4 reports.
kbd = BleHIDKeyboard(my_key_handler, my_mask_handler, my_button_handler, my_mouse_handler)
display.drawFill(0x000000)
display.drawText(5, 70, "Parsing WAD file...", 0xFFFFFF, "dejavusans20")
display.flush()
# The WAD uses the fpga_<id>.dat naming that FpgaFileServer serves.
# NOTE(review): presumably id 0x01c4546d is the IWAD the FPGA-side Doom reads; confirm.
wad = WADFile(base + "/fpga_01c4546d.dat")
display.drawText(5, 100, "Loading DOOM sounds...", 0xFFFFFF, "dejavusans20")
display.flush()
# Must be a module global: FpgaFileServer._handle_play_sound_irq reads `sounds`.
sounds = load_sounds(wad, sound_table)
display.drawText(5, 130, "Starting...", 0xFFFFFF, "dejavusans20")
display.flush()
sndmixer.begin(4, False)
# setup() enters the FPGA server loop and does not return normally.
setup(base)
