has_bluetooth = True
try:
    import bluetooth
except ImportError:
    has_bluetooth = False

if has_bluetooth:
    class BleHIDKeyboard:
        """BLE central that connects to a HID-over-GATT keyboard/mouse.

        It scans for a peripheral advertising the HID service (0x1812) and
        connects to the first one found. It then discovers the notifiable HID
        Report characteristics, enables notifications on them one CCCD write
        at a time, and decodes incoming reports into callbacks:

        * ``on_key(keycode, pressed)`` for 8-byte keyboard reports,
        * ``on_mask(modifier_mask, pressed)`` for modifier-bit changes,
        * ``on_button(index, pressed)`` for mouse buttons 0-2,
        * ``on_mouse(dx, dy, wheel)`` for 4-byte mouse reports.

        On disconnect, every key/modifier/button still held is released and
        scanning restarts.

        NOTE(review): this class never initiates pairing itself (there is no
        ``gap_pair`` call); the passkey handler only answers requests.
        """

        # IRQ event codes (values from the MicroPython ``bluetooth`` docs;
        # the module itself does not export them).
        IRQ_SCAN_RESULT = 5
        IRQ_SCAN_DONE = 6
        IRQ_PERIPHERAL_CONNECT = 7
        IRQ_PERIPHERAL_DISCONNECT = 8
        IRQ_GATTC_SERVICE_RESULT = 9
        IRQ_GATTC_SERVICE_DONE = 10
        IRQ_GATTC_CHARACTERISTIC_RESULT = 11
        IRQ_GATTC_CHARACTERISTIC_DONE = 12
        IRQ_GATTC_WRITE_DONE = 17
        IRQ_GATTC_NOTIFY = 18
        IRQ_ENCRYPTION_UPDATE = 28
        IRQ_PASSKEY_ACTION = 31

        # Passkey actions passed with IRQ_PASSKEY_ACTION. These are also not
        # attributes of the ``bluetooth`` module, so they are defined here.
        PASSKEY_ACTION_NONE = 0
        PASSKEY_ACTION_INPUT = 2
        PASSKEY_ACTION_DISPLAY = 3
        PASSKEY_ACTION_NUMERIC_COMPARISON = 4

        # Passkey shown to the user when the peripheral asks us to display one.
        # NOTE(review): a fixed value is weak; a random 6-digit value is better.
        DISPLAY_PASSKEY = 123456

        HID_SERVICE_UUID = bluetooth.UUID(0x1812)
        HID_REPORT_UUID  = bluetooth.UUID(0x2A4D)

        # GATT characteristic property bit: the characteristic supports notify
        # (and therefore has a CCCD).
        PROP_NOTIFY = 0x10
        # Little-endian CCCD value 0x0001 = "enable notifications".
        CCCD_ENABLE_NOTIFY = b"\x01\x00"

        # Modifier bits
        MOD_LCTRL  = 0x01
        MOD_LSHIFT = 0x02
        MOD_LALT   = 0x04
        MOD_LGUI   = 0x08
        MOD_RCTRL  = 0x10
        MOD_RSHIFT = 0x20
        MOD_RALT   = 0x40
        MOD_RGUI   = 0x80

        def __init__(self, on_key, on_mask, on_button, on_mouse):
            """Store the callbacks, activate BLE and start scanning."""
            self.on_key = on_key
            self.on_mask = on_mask
            self.on_button = on_button
            self.on_mouse = on_mouse

            # Previous report state, used to turn reports into edge events.
            self.last_keys = set()
            self.last_mods = 0
            self.last_buttons = 0

            self.ble = bluetooth.BLE()
            self.ble.active(True)
            # Register the handler only after all state it touches exists.
            self.ble.irq(self._irq)

            self.scan()

        def scan(self):
            """Reset the connection state and start an indefinite active scan."""
            self.connected = False
            self.pending_connect = None
            self.conn_handle = None
            self.hid_start = None
            self.hid_end = None
            self.report_handles = []
            self.cccd_queue = []
            self.cccd_index = 0
            # duration 0 = scan until stopped; interval/window 30000 us;
            # active=True also requests scan-response data.
            self.ble.gap_scan(0, 30000, 30000, True)

        def _parse_ad_structure(self, adv):
            """Return the set of 16-bit service UUIDs listed in advertising data.

            ``adv`` is a raw AD payload made of ``[length][type][data...]``
            fields. Only AD types 0x02/0x03 (incomplete/complete list of
            16-bit UUIDs) are collected. A zero length byte ends parsing,
            because the rest is padding. Truncated fields are clipped to the
            buffer instead of raising. Always returns a set (possibly empty).
            """
            adv = bytes(adv)
            end = len(adv)
            uuids = set()
            i = 0
            while i + 1 < end:
                field_len = adv[i]
                if field_len == 0:
                    break
                field_type = adv[i + 1]
                if field_type in (0x02, 0x03):
                    # Data spans adv[i+2 : i+1+field_len]; clip to the buffer.
                    data_end = min(i + 1 + field_len, end)
                    # Only complete little-endian byte pairs are read.
                    for j in range(i + 2, data_end - 1, 2):
                        uuids.add(bluetooth.UUID(adv[j] | (adv[j + 1] << 8)))
                i += field_len + 1
            return uuids

        # -----------------------------
        # IRQ handler
        # -----------------------------
        def _irq(self, event, data):
            """Dispatch a BLE event delivered via ``BLE.irq``.

            Drives the central state machine: scan -> connect -> discover the
            HID service -> discover Report characteristics -> enable
            notifications (one CCCD write at a time) -> decode notifications.
            """
            if event == self.IRQ_PASSKEY_ACTION:
                self._on_passkey_action(data)

            elif event == self.IRQ_ENCRYPTION_UPDATE:
                conn_handle, encrypted, authenticated, bonded, key_size = data
                print("Encryption update:", encrypted, "bonded:", bonded, "key_size:", key_size)

            elif event == self.IRQ_SCAN_RESULT:
                addr_type, addr, adv_type, rssi, adv_data = data
                # More results may arrive before the scan actually stops;
                # keep the first HID advertiser.
                if self.pending_connect is None and self.HID_SERVICE_UUID in self._parse_ad_structure(adv_data):
                    # addr is an IRQ-owned buffer; copy it before storing.
                    self.pending_connect = (addr_type, bytes(addr))
                    self.ble.gap_scan(None)

            elif event == self.IRQ_SCAN_DONE:
                if self.pending_connect is not None:
                    addr_type, addr = self.pending_connect
                    self.ble.gap_connect(addr_type, addr)

            elif event == self.IRQ_PERIPHERAL_CONNECT:
                self.conn_handle, addr_type, addr = data
                self.connected = True
                self.ble.gattc_discover_services(self.conn_handle)

            elif event == self.IRQ_PERIPHERAL_DISCONNECT:
                # Do not leave keys/buttons "stuck down" in the application.
                self._release_all()
                self.scan()

            elif event == self.IRQ_GATTC_SERVICE_RESULT:
                conn_handle, start, end, uuid = data
                if uuid == self.HID_SERVICE_UUID:
                    self.hid_start, self.hid_end = start, end

            elif event == self.IRQ_GATTC_SERVICE_DONE:
                if self.hid_start is None:
                    # Nothing to do on this device; disconnecting lets the
                    # disconnect handler restart scanning.
                    print("No HID service found, disconnecting")
                    self.ble.gap_disconnect(self.conn_handle)
                else:
                    self.report_handles = []
                    self.ble.gattc_discover_characteristics(
                        self.conn_handle, self.hid_start, self.hid_end
                    )

            elif event == self.IRQ_GATTC_CHARACTERISTIC_RESULT:
                conn_handle, def_handle, value_handle, properties, uuid = data
                # Only notifiable Report characteristics have a CCCD;
                # Output/Feature reports must not be written to.
                if uuid == self.HID_REPORT_UUID and properties & self.PROP_NOTIFY:
                    self.report_handles.append(value_handle)

            elif event == self.IRQ_GATTC_CHARACTERISTIC_DONE:
                # Build the queue of CCCD handles. Assumes each CCCD directly
                # follows its characteristic value handle.
                # NOTE(review): descriptor discovery would be more robust.
                self.cccd_queue = [h + 1 for h in self.report_handles]
                self.cccd_index = 0
                if self.cccd_queue:
                    self._write_next_cccd()
                else:
                    print("No notifiable HID report characteristics found")

            elif event == self.IRQ_GATTC_WRITE_DONE:
                conn_handle, handle, status = data
                if status != 0:
                    # Keep going: one failed CCCD must not block the others.
                    print("CCCD write failed:", status, "on handle", handle)
                # Move on to the next CCCD.
                self.cccd_index += 1
                self._write_next_cccd()

            elif event == self.IRQ_GATTC_NOTIFY:
                conn_handle, value_handle, report = data
                self._handle_hid_report(bytes(report))

        def _on_passkey_action(self, data):
            """Answer a pairing passkey request so pairing does not stall."""
            # Documented payload: (conn_handle, action, passkey). The passkey
            # is only meaningful for numeric comparison.
            conn_handle, action = data[0], data[1]
            passkey = data[2] if len(data) > 2 else 0
            print("PASSKEY ACTION:", action)

            if action == self.PASSKEY_ACTION_DISPLAY:
                # We choose the passkey and show it; the user types it on the
                # peripheral.
                print("Peripheral requests us to display a passkey")
                print("Passkey:", self.DISPLAY_PASSKEY)
                self.ble.gap_passkey(conn_handle, action, self.DISPLAY_PASSKEY)

            elif action == self.PASSKEY_ACTION_INPUT:
                print("Peripheral expects us to enter a passkey")
                # TODO: prompt the user; 000000 is used for now.
                self.ble.gap_passkey(conn_handle, action, 0)

            elif action == self.PASSKEY_ACTION_NUMERIC_COMPARISON:
                print("Numeric comparison requested")
                print("Passkey:", passkey)
                # Accept automatically for now.
                self.ble.gap_passkey(conn_handle, action, 1)

        def _write_next_cccd(self):
            """Enable notifications on the CCCD at ``cccd_index``, or report completion."""
            if self.cccd_index < len(self.cccd_queue):
                # mode 1 = write with response, so IRQ_GATTC_WRITE_DONE follows
                # and chains to the next CCCD.
                self.ble.gattc_write(
                    self.conn_handle,
                    self.cccd_queue[self.cccd_index],
                    self.CCCD_ENABLE_NOTIFY,
                    1,
                )
            else:
                print("All CCCD writes completed")

        def _release_all(self):
            """Emit release events for every key, modifier and button still held."""
            for k in self.last_keys:
                self.on_key(k, False)
            for bit in range(8):
                mask = 1 << bit
                if self.last_mods & mask:
                    self.on_mask(mask, False)
            for bit in range(3):
                if self.last_buttons & (1 << bit):
                    self.on_button(bit, False)
            self.last_keys = set()
            self.last_mods = 0
            self.last_buttons = 0

        @staticmethod
        def _signed8(b):
            """Interpret an unsigned byte as a two's-complement int8."""
            return b - 256 if b > 127 else b

        # -----------------------------
        # HID report handling
        # -----------------------------
        def _handle_hid_report(self, report):
            """Decode one HID report by its length and fire edge-triggered callbacks.

            4 bytes are treated as a mouse report (buttons, dx, dy, wheel) and
            8 bytes as a keyboard report (modifiers, reserved, 6 keycodes).
            Reports of any other length are ignored.
            """
            # mouse
            if len(report) == 4:
                buttons = report[0]
                dx = self._signed8(report[1])
                dy = self._signed8(report[2])
                # The wheel is signed too: 0xFF means one step down, not 255.
                wheel = self._signed8(report[3])
                for bit in range(3):
                    mask = 1 << bit
                    if (buttons & mask) and not (self.last_buttons & mask):
                        self.on_button(bit, True)
                    if (self.last_buttons & mask) and not (buttons & mask):
                        self.on_button(bit, False)
                self.on_mouse(dx, dy, wheel)
                self.last_buttons = buttons
                return

            # keyboard
            if len(report) == 8:
                mods = report[0]
                codes = report[2:8]
                # Codes 1-3 (ErrorRollOver/POSTFail/ErrorUndefined) mean the
                # key array is not valid; keep the previous key state.
                if any(1 <= c <= 3 for c in codes):
                    keys = self.last_keys
                else:
                    keys = set(codes) - {0}

                for bit in range(8):
                    mask = 1 << bit
                    if (mods & mask) and not (self.last_mods & mask):
                        self.on_mask(mask, True)
                    if (self.last_mods & mask) and not (mods & mask):
                        self.on_mask(mask, False)

                for k in keys - self.last_keys:
                    self.on_key(k, True)

                for k in self.last_keys - keys:
                    self.on_key(k, False)

                self.last_keys = keys
                self.last_mods = mods
else:
    class BleHIDKeyboard:
        """Fallback used when the firmware has no ``bluetooth`` module.

        Accepts the same constructor arguments as the real implementation and
        exposes the same public ``scan()`` method, but never produces any
        callbacks.
        """

        def __init__(self, on_key, on_mask, on_button, on_mouse):
            print("Micropython has no bluetooth support")

        def scan(self):
            """No-op counterpart of the real ``scan()``, so callers need no guard."""
            return None
