import bluetooth
import utime
import ujson
import _thread
import time

class BleHIDKeyboard:
    """BLE central that connects to a HID keyboard and reports key events.

    While disconnected, a background thread keeps a scan running.  A scan
    result that matches the stored bond address, or whose advertised name
    contains "Keyboard", is connected to.  After connecting, the HID service
    is discovered, notifications are enabled on a HID Report characteristic,
    and every 8-byte report is translated into ``on_key(code, pressed)``
    calls.
    """

    # IRQ constants
    IRQ_SCAN_RESULT = 5
    IRQ_SCAN_DONE = 6
    IRQ_PERIPHERAL_CONNECT = 7
    IRQ_PERIPHERAL_DISCONNECT = 8
    IRQ_GATTC_SERVICE_RESULT = 9
    IRQ_GATTC_SERVICE_DONE = 10
    IRQ_GATTC_CHARACTERISTIC_RESULT = 11
    IRQ_GATTC_CHARACTERISTIC_DONE = 12
    IRQ_GATTC_NOTIFY = 18
    IRQ_PASSKEY_ACTION = 31
    IRQ_ENCRYPTION_UPDATE = 28

    # Passkey actions delivered with IRQ_PASSKEY_ACTION.  These are defined
    # here (values as listed in the MicroPython bluetooth documentation)
    # instead of being read from the ``bluetooth`` module.
    PASSKEY_ACTION_NONE = 0
    PASSKEY_ACTION_INPUT = 2
    PASSKEY_ACTION_DISPLAY = 3
    PASSKEY_ACTION_NUMERIC_COMPARISON = 4

    HID_SERVICE_UUID = bluetooth.UUID(0x1812)
    HID_REPORT_UUID  = bluetooth.UUID(0x2A4D)

    # GATT characteristic property bit: the characteristic can notify.
    CHAR_PROP_NOTIFY = 0x10

    # Key-array value a keyboard sends in every slot when too many keys are
    # held at once (ErrorRollOver in the HID usage tables).
    KEY_ERR_ROLLOVER = 0x01

    BOND_FILE = "hid_bond.json"

    # Modifier bits
    MOD_LCTRL  = 0x01
    MOD_LSHIFT = 0x02
    MOD_LALT   = 0x04
    MOD_LGUI   = 0x08
    MOD_RCTRL  = 0x10
    MOD_RSHIFT = 0x20
    MOD_RALT   = 0x40
    MOD_RGUI   = 0x80

    # Keymaps
    HID_KEYMAP = { ... }  # unchanged
    HID_KEYMAP_SHIFTED = { ... }  # unchanged

    def __init__(self, on_key):
        """Activate BLE, load the stored bond and start the scan thread.

        on_key -- callable invoked as ``on_key(code, pressed)``; ``code`` is
                  a modifier bit mask (``MOD_*``) or a key code taken from
                  the report, ``pressed`` is True on press, False on release.
        """
        self.on_key = on_key

        # Every attribute read by _irq / _scan_manager is set up BEFORE the
        # IRQ handler is registered and the thread is started, so neither
        # can observe a half-initialised object.
        self.connected = False
        self.scanning = False
        self.pending_connect = None
        self.conn_handle = None
        self.hid_start = None
        self.hid_end = None
        self.report_handle = None

        self.last_keys = set()
        self.last_mods = 0

        # Load bonding info (address of the previously paired peer, if any)
        self.bond = self._load_bond()
        if self.bond:
            addr_type, addr = self.bond
            print("Loaded bond for", addr_type, addr)

        self.ble = bluetooth.BLE()
        self.ble.active(True)
        self.ble.irq(self._irq)

        # Start scan manager thread
        _thread.start_new_thread(self._scan_manager, ())

    # -----------------------------
    # Scan manager thread
    # -----------------------------
    def _scan_manager(self):
        """Loop forever: keep a scan running while there is no connection."""
        while True:
            try:
                if not self.connected:
                    if not self.scanning:
                        # Start continuous scan (duration 0)
                        self.scanning = True
                        self.ble.gap_scan(0)
                elif self.scanning:
                    # Stop scanning when connected
                    self.scanning = False
                    self.ble.gap_scan(None)
            except OSError as e:
                # An unhandled exception would end this thread and scanning
                # would never restart; clear the flag and retry next pass.
                self.scanning = False
                print("Scan control failed:", e)

            time.sleep(1)

    def _request_connect(self, addr_type, addr):
        """Remember the peer and stop the scan; IRQ_SCAN_DONE then connects."""
        self.pending_connect = (addr_type, addr)
        self.ble.gap_scan(None)

    def _reset_discovery(self):
        """Forget GATT handles learned from a previous connection."""
        self.hid_start = None
        self.hid_end = None
        self.report_handle = None

    # -----------------------------
    # IRQ handler
    # -----------------------------
    def _irq(self, event, data):
        """Dispatch one BLE event.

        event -- integer event code (compared against the ``IRQ_*`` constants)
        data  -- event-specific tuple
        """
        print("IRQ:", event, "DATA:", data)

        # Pairing request
        if event == self.IRQ_PASSKEY_ACTION:
            # Indexed rather than unpacked so that both a
            # (conn_handle, action) and a (conn_handle, action, passkey)
            # payload are accepted.
            conn_handle = data[0]
            action = data[1]
            passkey = data[2] if len(data) > 2 else 0
            print("PASSKEY ACTION:", action)

            if action == self.PASSKEY_ACTION_DISPLAY:
                # The peripheral wants US to display a passkey.
                # TODO: generate a passkey, show it and pass it to
                # gap_passkey(); nothing is sent back at the moment.
                print("Peripheral requests us to display a passkey")

            elif action == self.PASSKEY_ACTION_INPUT:
                # The peripheral wants US to enter a passkey
                print("Peripheral expects us to enter a passkey")
                # For now, enter 000000 or prompt the user
                self.ble.gap_passkey(conn_handle, action, 0)

            elif action == self.PASSKEY_ACTION_NUMERIC_COMPARISON:
                print("Numeric comparison requested", passkey)
                # Accept automatically for now
                self.ble.gap_passkey(conn_handle, action, 1)

            return

        # Encryption / bonding update
        if event == self.IRQ_ENCRYPTION_UPDATE:
            conn_handle, encrypted, authenticated, bonded, key_size = data
            print("Encryption update:", encrypted, "bonded:", bonded, "key_size:", key_size)

            # First time we see a bonded link, remember this peer's address
            if bonded and not self.bond and self.pending_connect:
                addr_type, addr = self.pending_connect
                self._save_bond(addr_type, addr)
                self.bond = (addr_type, addr)
            return

        # Normal events
        if event == self.IRQ_SCAN_RESULT:
            addr_type, addr, adv_type, rssi, adv_data = data
            addr = bytes(addr)

            # The bonded device is connected to regardless of its name
            if self.bond:
                bond_type, bond_addr = self.bond
                if addr_type == bond_type and addr == bond_addr:
                    print("Found bonded device, connecting")
                    self._request_connect(addr_type, addr)
                    return

            # Any other advertiser still goes through the name-based filter
            name = self._adv_decode_name(adv_data)
            if name:
                print("SCAN:", name, addr)
                if "Keyboard" in name:
                    self._request_connect(addr_type, addr)

        elif event == self.IRQ_SCAN_DONE:
            if self.connected:
                # The scan ended because a link is already up; connecting
                # again would target the peer we are already talking to.
                return

            if self.pending_connect:
                addr_type, addr = self.pending_connect
                print("Connecting to:", addr)
                try:
                    self.ble.gap_connect(addr_type, addr)
                except OSError as e:
                    # Drop the attempt so the scan manager starts over.
                    print("Connect failed:", e)
                    self.pending_connect = None
                    self.scanning = False
            else:
                print("No pending connect, rescanning...")
                self.scanning = True
                self.ble.gap_scan(0)

        elif event == self.IRQ_PERIPHERAL_CONNECT:
            self.conn_handle, addr_type, addr = data
            self.connected = True
            self._reset_discovery()
            print("Connected:", bytes(addr))
            utime.sleep_ms(200)
            self.ble.gattc_discover_services(self.conn_handle)

        elif event == self.IRQ_PERIPHERAL_DISCONNECT:
            # Clear every piece of per-connection state.  Resetting
            # ``scanning`` and ``pending_connect`` is what allows the scan
            # manager to start a fresh scan.
            self.connected = False
            self.scanning = False
            self.pending_connect = None
            self.conn_handle = None
            self._reset_discovery()
            print("Disconnected, restarting scan")
            # An all-zero report releases whatever was still held, so the
            # callback never sees a key stuck down after the link is gone.
            self._handle_hid_report(bytes(8))

        elif event == self.IRQ_GATTC_SERVICE_RESULT:
            ch, start, end, uuid = data
            if uuid == self.HID_SERVICE_UUID:
                print("Found HID service")
                self.hid_start, self.hid_end = start, end

        elif event == self.IRQ_GATTC_SERVICE_DONE:
            if self.hid_start is not None:
                self.ble.gattc_discover_characteristics(
                    self.conn_handle, self.hid_start, self.hid_end
                )

        elif event == self.IRQ_GATTC_CHARACTERISTIC_RESULT:
            ch, def_handle, value_handle, properties, uuid = data
            # A HID service can expose several Report characteristics; only
            # one that can notify is usable here, and the first is kept.
            if (
                uuid == self.HID_REPORT_UUID
                and (properties & self.CHAR_PROP_NOTIFY)
                and self.report_handle is None
            ):
                print("Found HID report characteristic")
                self.report_handle = value_handle

        elif event == self.IRQ_GATTC_CHARACTERISTIC_DONE:
            if self.report_handle is not None:
                # NOTE(review): assumes the CCCD is the descriptor directly
                # after the value handle; descriptor discovery would be
                # needed to confirm this for a given device.
                cccd = self.report_handle + 1
                self.ble.gattc_write(self.conn_handle, cccd, b"\x01\x00", 1)
                print("Subscribed to HID notifications")

        elif event == self.IRQ_GATTC_NOTIFY:
            conn_handle, value_handle, report = data
            self._handle_hid_report(bytes(report))

    # -----------------------------
    # Bonding persistence
    # -----------------------------
    def _save_bond(self, addr_type, addr):
        """Write the peer address to ``BOND_FILE`` as JSON (best effort)."""
        print("Saving bond info")
        try:
            with open(self.BOND_FILE, "w") as f:
                f.write(ujson.dumps({
                    "addr_type": addr_type,
                    "addr": list(addr),  # bytes → list of ints
                }))
        except Exception as e:
            print("Bond save failed:", e)

    def _load_bond(self):
        """Return ``(addr_type, addr_bytes)`` from ``BOND_FILE``, or None.

        None is returned when the file is missing, unreadable, not valid
        JSON, or lacks the expected keys.
        """
        try:
            with open(self.BOND_FILE, "r") as f:
                d = ujson.loads(f.read())
            return d["addr_type"], bytes(d["addr"])
        except (OSError, ValueError, KeyError, TypeError):
            return None

    # -----------------------------
    # HID report handling
    # -----------------------------
    def _handle_hid_report(self, report):
        """Diff an 8-byte report against the previous one and emit events.

        Byte 0 is read as the modifier bit field and bytes 2..7 as the
        currently held key codes (0 = empty slot).  Reports of any other
        length are ignored.  Modifiers are reported to ``on_key`` as their
        bit mask and keys as their code.
        NOTE(review): the two value ranges overlap (e.g. 0x04), so the
        callback cannot tell them apart by value alone.
        """
        if len(report) != 8:
            return

        mods = report[0]
        keys = set(report[2:8]) - {0}

        if self.KEY_ERR_ROLLOVER in keys:
            # Too many keys held: the slots carry an error marker instead of
            # real key codes, so keep the previous state untouched.
            return

        changed = mods ^ self.last_mods
        for bit in range(8):
            mask = 1 << bit
            if changed & mask:
                self.on_key(mask, bool(mods & mask))

        for k in keys - self.last_keys:
            self.on_key(k, True)

        for k in self.last_keys - keys:
            self.on_key(k, False)

        self.last_keys = keys
        self.last_mods = mods

    # -----------------------------
    # Advertisement name decoder
    # -----------------------------
    def _adv_decode_name(self, adv_data):
        """Return the device name from an advertising payload, or None.

        The payload is walked as a sequence of length/type/value fields; the
        first field of type 0x08 or 0x09 is decoded as the name.  Malformed
        or truncated payloads yield None instead of raising.
        """
        adv = bytes(adv_data)
        total = len(adv)
        i = 0
        # Each field needs at least a length byte and a type byte.
        while i + 1 < total:
            length = adv[i]
            if length == 0:
                return None
            ad_type = adv[i + 1]
            if ad_type in (0x08, 0x09):
                try:
                    return adv[i + 2 : i + 1 + length].decode("utf-8", "ignore")
                except UnicodeError:
                    # Raised if the runtime does not honour "ignore".
                    return None
            i += 1 + length
        return None
