import pyxel
import random
import math

# ── Sound slot indices (arguments to pyxel.sound / pyxel.play) ───
# Consecutive slots 0-9; adding a name here means bumping the range too.
(
    SND_GUN,
    SND_SHOTGUN,
    SND_ENEMY_ATTACK,
    SND_BOSS_ATTACK,
    SND_PLAYER_HURT,
    SND_ENEMY_DEATH,
    SND_LEVEL_COMPLETE,
    SND_WIN_MELODY,
    SND_UI_CLICK,
    SND_GRAVITY,
) = range(10)


def setup_sounds():
    """Register every sound effect and the background-music track with Pyxel.

    Sound-effect slots 0-9 correspond to the ``SND_*`` constants.  Background
    music lives in slots 40-48 (kept clear of the SFX slots) and is assembled
    into music track 1.  The melody plays on music channel 2 and the bass on
    channel 3.  Channels 0 and 1 are left empty because the sound effects in
    this file are played on those channels.  Music volumes are kept low (1-3)
    so the track sits behind the action sounds.

    Slots 47 (kick) and 48 (snare/hat) are defined but are not currently part
    of the music track.
    """
    # (slot, notes, tones, volumes, effects, speed)
    sound_effects = (
        (SND_GUN,            "c3a2f2c2",             "p", "7", "f",  8),
        (SND_SHOTGUN,        "c2a1f1c1",             "n", "7", "f",  6),
        (SND_ENEMY_ATTACK,   "f3d3",                 "p", "6", "f", 10),
        (SND_BOSS_ATTACK,    "c2a1f1d1",             "n", "7", "f",  8),
        (SND_PLAYER_HURT,    "e2d2c2",               "p", "7", "f",  9),
        (SND_ENEMY_DEATH,    "g3e3c3a2",             "n", "6", "f",  9),
        (SND_LEVEL_COMPLETE, "c3e3g3c4e4g4",         "t", "7", "n", 12),
        (SND_WIN_MELODY,     "c3e3g3c4e4g4c4g4e4c4", "t", "7", "f", 12),
        (SND_UI_CLICK,       "c4",                   "p", "5", "n", 20),
        (SND_GRAVITY,        "a2c3e3d3c3a2",         "t", "6", "v", 10),
    )
    music_parts = (
        # Melody lead phrases.
        (40, "a3 a3 a3 a3 g3 f3 d3 r a3 a3 a3 a3 g3 f3 d3 r", "S", "2", "N", 14),
        (41, "f3 f3 f3 e3 d3 r d3 e3 f3 r f3 g3 a3 a3 g3 f3 e3 r", "S", "2", "N", 14),
        (42, "a3 a3 g3 f3 r f3 g3 a3 r a3 a3 g3 f3 d3 e3 f3 r", "S", "2", "V", 14),
        (43, "f3 f3 e3 d3 r d3 e3 f3 r g3 g3 f3 e3 f3 g3 a3 r", "P", "3", "V", 14),
        (44, "f3 f3 f3 e3 d3 r d3 e3 f3 r f3 g3 a3 a3 g3 f3 f3 r", "S", "2", "N", 14),
        (45, "f3 e3 f3 e3 f3 g3 r a3 a3 a3 g3 f3 d3 e3 f3 r", "S", "1", "F", 14),
        # Bass loop.
        (46, "f1 f1 r r c2 c2 r r d2 d2 r r a#1 a#1 r r", "T", "1", "N", 14),
        # Kick, then snare/hat (not referenced by the track below).
        (47, "c1 r c1 r c1 r c1 r c1 r c1 r c1 r c1 r", "P", "1", "F", 7),
        (48, "r c1 r c1 r c1 r c1 r c1 r c1 r c1 r c1", "N", "1", "F", 7),
    )

    for slot, notes, tones, volumes, effects, speed in sound_effects + music_parts:
        pyxel.sound(slot).set(notes, tones, volumes, effects, speed)

    melody_order = [40, 41, 42, 43, 44, 42, 43, 45, 43]
    # One bass loop per melody phrase; channels 0 and 1 stay empty for SFX.
    pyxel.music(1).set([], [], melody_order, [46] * len(melody_order))


# Screen size in pixels.
WIDTH = 512
HEIGHT = 288

# Added to a body's vy every update (px/frame^2), and the cap on falling speed.
GRAVITY = 0.38
MAX_FALL = 8

# Player hitbox size in pixels (the drawn sprite is larger and is centred on it).
PLAYER_W = 18
PLAYER_H = 18

# NOTE(review): not referenced in this part of the file; presumably the number
# of levels needed to win and the life cap — confirm at the use sites.
LEVEL_COUNT_FOR_WIN = 10
MAX_LIVES = 5

# Shop prices.  NOTE(review): not referenced in this part of the file.
SHOP_STRENGTH_COST = 12
SHOP_SPEED_COST = 10
SHOP_JUMP_COST = 10
SHOP_HEART_COST = 14
SHOP_AMMO_COST = 8

# Weapon table, indexed by gun id (Player.gun_idx).
#   rate   frames between shots (copied into Player.shoot_timer)
#   count  projectiles per shot; spread = radians between adjacent pellets
#   dmg    base damage per projectile (Player.strength_bonus is added)
#   col    palette colour of the projectile / icon
#   ammo   starting ammo; -1 never decrements, i.e. unlimited
#   range  max travel distance (Bullet: |dx|+|dy|; Rocket: horizontal only)
#   unlock_level  NOTE(review): not read in this part of the file
GUN_TYPES = [
    {"name": "PI", "full_name": "PISTOL",      "rate": 12, "count": 1, "spread": 0.00, "dmg": 1, "col": 10, "ammo": -1, "range": 180, "unlock_level": 1},
    {"name": "SH", "full_name": "SHOTGUN",     "rate": 25, "count": 5, "spread": 0.22, "dmg": 2, "col": 9,  "ammo": 20, "range": 95,  "unlock_level": 2},
    {"name": "RA", "full_name": "RAPID FIRE",  "rate": 4,  "count": 1, "spread": 0.03, "dmg": 1, "col": 11, "ammo": 50, "range": 150, "unlock_level": 3},
    {"name": "SN", "full_name": "SNIPER",      "rate": 38, "count": 1, "spread": 0.00, "dmg": 5, "col": 15, "ammo": 10, "range": 260, "unlock_level": 5},
    {"name": "RO", "full_name": "ROCKET",      "rate": 52, "count": 1, "spread": 0.00, "dmg": 8, "col": 8,  "ammo": 6,  "range": 170, "unlock_level": 7},
]

# Playable kitties; the three lists are parallel and indexed by Player.skin_idx.
KITTY_NAMES  = ["PINK",  "DARK",  "RED"]
KITTY_ACCENT = [14,      6,       8]
# Per-kitty stat bonuses added on top of purchased upgrades in Player.__init__;
# grav_cooldown_bonus is subtracted from the gravity-gun cooldown (frames).
KITTY_TRAITS = [
    {"name": "HIGH JUMP", "jump": 1, "speed": 0.0, "strength": 0, "grav_cooldown_bonus": 0},
    {"name": "FAST",      "jump": 0, "speed": 0.5, "strength": 0, "grav_cooldown_bonus": 0},
    {"name": "STRONG",    "jump": 0, "speed": 0.0, "strength": 1, "grav_cooldown_bonus": 25},
]

# Sprite atlas layout in image bank 0, filled by load_all_sprites():
# name -> (u, v, w, h, colkey), as consumed by blt_sprite().
SPRITE_DEST: dict = {}


def load_all_sprites() -> None:
    """Build the sprite atlas in image bank 0 and fill ``SPRITE_DEST``.

    Each sprite lives in its own ``.pyxres`` file.  The files are loaded one at
    a time, restricted to image banks where the installed Pyxel version allows
    it.  The configured tile rectangle is read out of the file's image bank.
    After every file has been read, the collected pixels are written into
    image bank 0 at each sprite's destination.  ``SPRITE_DEST`` then maps
    ``name -> (u, v, w, h, colkey)`` for ``blt_sprite``.

    A file that cannot be loaded does not abort start-up.  Its atlas slot is
    filled with the sprite's colour key (drawn fully transparent) and a warning
    is printed to stderr.  This replaces the old behaviour of silently copying
    whatever the previously loaded file left in the bank.
    """
    import sys

    # name, file, source bank,
    # source tile rect x1, y1, x2, y2 (inclusive, 8 px tiles),
    # destination tile x, y in bank 0, transparent colour key
    configs = [
        ("pink",   "res.pyxres",         0,  1,  0,  4,  3,  0,  0, 11),
        ("dark",   "HelloKitty2.pyxres", 0, 17, 16, 20, 19,  4,  0, 11),
        ("red",    "HelloKitty3.pyxres", 0,  0,  0,  3,  3,  8,  0, 11),
        ("enemy1", "Badguy1.pyxres",     0,  0,  0,  3,  4,  0,  4,  7),
        ("enemy2", "Badguy2.pyxres",     0,  0,  0,  3,  4,  4,  4,  7),
        ("enemy3", "Badguy3.pyxres",     0,  1,  0,  5,  3,  8,  4,  7),
        ("boss",   "Boss.pyxres",        0,  0,  0,  2,  7,  0,  9, 11),
    ]

    # "Load image banks only", spelled for each Pyxel API generation.  Positional
    # booleans are ambiguous: older releases took include flags
    # (image, tilemap, sound, music) and newer ones take exclude flags, so one
    # positional call loads images on one version and skips them on another,
    # without raising.  Keyword arguments either mean the right thing or raise
    # TypeError, which lets us probe safely.
    image_only_kwargs = (
        {"excl_images": False, "excl_tilemaps": True,
         "excl_sounds": True, "excl_musics": True},
        {"exclude_images": False, "exclude_tilemaps": True,
         "exclude_sounds": True, "exclude_musics": True},
        {"image": True, "tilemap": False, "sound": False, "music": False},
    )

    def _load_images_only(filename):
        """Load *filename*'s image banks; return False if the file can't be loaded."""
        for kwargs in image_only_kwargs:
            try:
                pyxel.load(filename, **kwargs)
                return True
            except TypeError:
                continue  # this Pyxel version doesn't know these keyword names
            except Exception:
                return False  # the file itself is the problem; retrying won't help
        # Unknown API: full load (note this also replaces sounds and music).
        try:
            pyxel.load(filename)
            return True
        except Exception:
            return False

    collected = {}
    for name, filename, src_bank, tx1, ty1, tx2, ty2, dtx, dty, colkey in configs:
        sx, sy = tx1 * 8, ty1 * 8
        w = (tx2 - tx1 + 1) * 8
        h = (ty2 - ty1 + 1) * 8
        if _load_images_only(filename):
            img = pyxel.image(src_bank)
            pixels = [img.pget(sx + c, sy + r) for r in range(h) for c in range(w)]
        else:
            print(f"warning: could not load {filename}; sprite '{name}' will be blank",
                  file=sys.stderr)
            pixels = [colkey] * (w * h)
        collected[name] = (dtx * 8, dty * 8, w, h, colkey, pixels)

    # Compose only after every file is read: each pyxel.load overwrites bank 0.
    atlas = pyxel.image(0)
    for dx, dy, w, h, _colkey, pixels in collected.values():
        for r in range(h):
            row_start = r * w
            for c in range(w):
                atlas.pset(dx + c, dy + r, pixels[row_start + c])

    # Derived from `configs` so the layout can't drift from what was copied.
    # Updated in place so any module holding a reference to the dict sees it.
    SPRITE_DEST.clear()
    SPRITE_DEST.update(
        (name, (dx, dy, w, h, colkey))
        for name, (dx, dy, w, h, colkey, _pixels) in collected.items()
    )


def blt_sprite(name: str, screen_x: int, screen_y: int,
               flip: bool = False, flash: bool = False,
               draw_h: "int | None" = None) -> None:
    """Draw atlas sprite *name* (see ``SPRITE_DEST``) from image bank 0.

    Args:
        name: key into ``SPRITE_DEST``.
        screen_x, screen_y: top-left corner on screen.
        flip: mirror the sprite horizontally.
        flash: remap palette colours 1-15 to colour 7 for this draw (hit
            feedback); colour 0 is left unchanged.
        draw_h: if given, draw only the top ``draw_h`` rows.  This crops the
            bottom of the sprite, e.g. to hide artifact rows in the art.  It is
            clamped to ``[0, sprite height]`` so a crop can never pull in pixels
            from the atlas region below the sprite, or flip it with a negative
            height.
    """
    u, v, w, h, colkey = SPRITE_DEST[name]
    if draw_h is not None:
        h = max(0, min(draw_h, h))
    blt_w = -w if flip else w  # negative width mirrors horizontally in pyxel.blt
    if flash:
        for c in range(1, 16):
            pyxel.pal(c, 7)
    try:
        pyxel.blt(screen_x, screen_y, 0, u, v, blt_w, h, colkey)
    finally:
        # Always restore the palette so a failed draw can't tint later frames.
        if flash:
            pyxel.pal()


def draw_sprite(sprite, dx, dy, scale=1, flash=False, flipped=False):
    """Draw a sprite given as rows of palette indices, top-left at (dx, dy).

    Cells equal to 0 are transparent.  Each cell becomes a ``scale`` x
    ``scale`` block.  ``flipped`` mirrors the sprite horizontally, and
    ``flash`` paints every visible cell in colour 7.  The sprite width is taken
    from the first row.
    """
    height = len(sprite)
    if height == 0:
        return
    width = len(sprite[0])
    for r, row in enumerate(sprite):
        y = dy + r * scale
        for c in range(width):
            colour = row[c]
            if colour == 0:
                continue
            column = width - 1 - c if flipped else c
            x = dx + column * scale
            paint = 7 if flash else colour
            if scale == 1:
                pyxel.pset(x, y, paint)
            else:
                pyxel.rect(x, y, scale, scale, paint)


def draw_gun_graphic(gun_name, x, y, col, facing=1):
    if gun_name == "PI":
        pyxel.rect(x, y + 1, 7, 3, col)
        pyxel.rect(x + 1, y, 2, 1, col)
        pyxel.rect(x + 5, y + 3, 2, 2, col)
    elif gun_name == "SH":
        pyxel.rect(x, y + 1, 10, 3, col)
        pyxel.rect(x + 1, y, 3, 1, col)
        pyxel.rect(x + 3, y + 2, 1, 4, col)
        pyxel.rect(x + 7, y + 2, 3, 1, col)
    elif gun_name == "RA":
        pyxel.rect(x, y + 1, 11, 2, col)
        pyxel.rect(x + 2, y, 6, 1, col)
        pyxel.rect(x + 1, y + 3, 4, 2, col)
        pyxel.pset(x + 10, y + 1, 15)
        pyxel.pset(x + 10, y + 2, 15)
    elif gun_name == "SN":
        pyxel.rect(x, y + 2, 14, 2, col)
        pyxel.rect(x + 3, y + 1, 4, 1, col)
        pyxel.rect(x + 2, y + 3, 3, 2, col)
        pyxel.pset(x + 13, y + 2, 15)
    elif gun_name == "RO":
        pyxel.rect(x, y + 1, 11, 4, col)
        pyxel.circ(x + 5, y + 3, 3, col)
        pyxel.pset(x + 10, y + 3, 10)
        pyxel.pset(x + 11, y + 3, 9)


def rects_overlap(a, b):
    """Return True if rectangles *a* and *b*, each ``(x, y, w, h)``, intersect.

    Rectangles that only share an edge do not count as overlapping.
    """
    if not (a[0] < b[0] + b[2] and b[0] < a[0] + a[2]):
        return False
    return a[1] < b[1] + b[3] and b[1] < a[1] + a[3]


def clamp(v, lo, hi):
    """Limit *v* to ``[lo, hi]``; when the bounds are crossed, *lo* wins."""
    capped = hi if hi < v else v
    return capped if capped > lo else lo


def point_in_rect(px, py, r):
    """Return True if (px, py) lies in rect ``r = (x, y, w, h)`` (right/bottom edges exclusive)."""
    if not (r[0] <= px < r[0] + r[2]):
        return False
    return r[1] <= py < r[1] + r[3]


def hazard_at_point(px, py, spikes):
    """Return True if the point (px, py) lies inside any of the *spikes* rects."""
    return any(point_in_rect(px, py, spike) for spike in spikes)


def has_ground_below(px, py, solids, depth=8):
    """Return True if a 2 px wide, *depth* px tall probe at (px, py) touches any solid."""
    probe = (px, py, 2, depth)
    return any(rects_overlap(probe, solid) for solid in solids)


def is_danger_ahead(entity, solids, spikes, direction, look_distance=8):
    """Return True if walking in *direction* would lead into spikes or off a ledge.

    The probe column sits *look_distance* px beyond the leading edge of
    *entity* (which needs ``x``, ``y``, ``w``, ``h``).  Spikes are checked just
    above and just below foot level.  Ground is checked with a 10 px deep probe
    starting one pixel below the feet.
    """
    if direction > 0:
        probe_x = entity.x + entity.w + look_distance
    else:
        probe_x = entity.x - look_distance
    feet_y = entity.y + entity.h + 1
    spiked = (hazard_at_point(probe_x, feet_y - 2, spikes)
              or hazard_at_point(probe_x, feet_y + 2, spikes))
    return spiked or not has_ground_below(probe_x, feet_y, solids, 10)


class LevelPortal:
    """Animated level doorway; ``rect()`` is its 28x48 trigger box in world space.

    Regular portals use a per-level colour scheme and an "LVL<n>" label.  The
    final portal (``is_final``) cycles through colours and is labelled "FINISH".
    """

    # (col1, col2, col_inner) per portal, indexed by level_no - 1.  col1 is the
    # frame/accent colour, col2 the body fill and col_inner the blinking core.
    # Kept at class level so draw() doesn't rebuild the table every frame.
    _LEVEL_COLORS = (
        (13, 5, 1), (11, 3, 1), (9, 2, 1), (8, 2, 0), (12, 1, 0),
        (14, 5, 1), (15, 7, 1), (13, 2, 0), (10, 3, 0),
    )

    def __init__(self, x, y, level_no, is_final=False):
        self.x = float(x)
        self.y = float(y)
        self.w = 28
        self.h = 48
        self.level_no = level_no
        self.is_final = is_final
        self.anim = random.randint(0, 60)  # random phase so portals don't pulse in sync
        self.active = True

    def rect(self):
        return (int(self.x), int(self.y), self.w, self.h)

    def _theme(self, t):
        """Return ``(col1, col2, col_inner, label, label_col)`` for animation frame *t*."""
        if self.is_final:
            col1 = (10, 9, 14, 13, 15)[t // 6 % 5]
            col2 = (15, 14, 13, 9, 10)[(t // 6 + 2) % 5]
            col_inner = (9, 10, 15, 14)[(t // 8) % 4]
            return col1, col2, col_inner, "FINISH", 10
        # Clamp at both ends.  Levels past the table reuse the last scheme, and a
        # level_no below 1 uses the first scheme instead of wrapping to the last
        # one through a negative index.
        ci = max(0, min(self.level_no - 1, len(self._LEVEL_COLORS) - 1))
        col1, col2, col_inner = self._LEVEL_COLORS[ci]
        # NOTE(review): the label shows level_no + 1 — presumably the level this
        # portal leads to; confirm against the level-loading code.
        return col1, col2, col_inner, f"LVL{self.level_no + 1}", col1

    def draw(self, cam_x):
        """Draw the portal, advancing its animation by one frame."""
        px = int(self.x - cam_x)
        py = int(self.y)
        self.anim += 1
        t = self.anim
        col1, col2, col_inner, label, label_col = self._theme(t)

        # Body and blinking inner core.
        pyxel.rect(px + 4, py, 20, self.h, col2)
        pyxel.rectb(px + 4, py, 20, self.h, col1)
        inner_col = col_inner if t % 16 < 8 else col1
        pyxel.rect(px + 6, py + 2, 16, self.h - 4, inner_col)
        # Six orbiting dots.
        for i in range(6):
            ang = t * 0.12 + i * (math.tau / 6)
            rx = px + 14 + int(math.cos(ang) * 6)
            ry = py + self.h // 2 + int(math.sin(ang) * 12)
            if py <= ry <= py + self.h:
                pyxel.pset(rx, ry, col1 if i % 2 == 0 else col2)
        # Pulsing centre glow.
        glow_col = col1 if t % 20 < 10 else 15
        pyxel.circ(px + 14, py + self.h // 2, 4, glow_col)
        pyxel.pset(px + 14, py + self.h // 2, 15)
        # Side pillars and lintel.
        pyxel.rect(px, py + 4, 6, self.h - 4, col1)
        pyxel.rect(px + 22, py + 4, 6, self.h - 4, col1)
        pyxel.rectb(px, py + 4, 6, self.h - 4, col2)
        pyxel.rectb(px + 22, py + 4, 6, self.h - 4, col2)
        pyxel.rect(px, py, self.w, 6, col1)
        pyxel.rectb(px, py, self.w, 6, col2)
        # Outer sparks, each visible for half of a 60-frame cycle.
        for i in range(4):
            spark_t = (t + i * 15) % 60
            spark_x = px + 14 + int(math.cos(t * 0.07 + i * 1.5) * 16)
            spark_y = py + self.h // 2 + int(math.sin(t * 0.07 + i * 1.5) * 22)
            if spark_t < 30:
                pyxel.pset(spark_x, spark_y, col1)
        # Boxed label above, blinking "ENTER" hint below.
        lx = px + self.w // 2 - len(label) * 2
        ly = py - 10
        pyxel.rect(lx - 2, ly - 1, len(label) * 4 + 4, 9, 0)
        pyxel.rectb(lx - 2, ly - 1, len(label) * 4 + 4, 9, label_col)
        pyxel.text(lx, ly + 1, label, label_col)
        if t % 40 < 20:
            hint = "ENTER"
            pyxel.text(px + self.w // 2 - len(hint) * 2, py + self.h + 3, hint, 7)


class MovingPlatform:
    """Solid 8 px tall platform that oscillates sinusoidally around its spawn point.

    ``move_x`` / ``move_y`` scale the horizontal / vertical swing (0 disables
    that axis).  ``range_val`` is the swing amplitude in pixels and ``speed``
    scales how fast the phase ``t`` advances.  ``prev_x`` / ``prev_y`` keep
    the position from before the latest ``update`` so riders can be carried.
    """

    def __init__(self, x, y, w, move_x=0, move_y=0, range_val=60, speed=0.8):
        self.ox, self.oy = float(x), float(y)
        self.x, self.y = self.ox, self.oy
        self.prev_x, self.prev_y = self.ox, self.oy
        self.w, self.h = w, 8
        self.move_x, self.move_y = move_x, move_y
        self.range_val = range_val
        self.speed = speed
        self.t = 0.0

    def update(self):
        self.prev_x, self.prev_y = self.x, self.y
        self.t += self.speed * 0.02
        phase = math.sin(self.t)
        if self.move_x:
            self.x = self.ox + phase * self.range_val * self.move_x
        if self.move_y:
            self.y = self.oy + phase * self.range_val * self.move_y

    def rect(self):
        return int(self.x), int(self.y), self.w, self.h

    def draw(self, cam_x):
        left, top = int(self.x - cam_x), int(self.y)
        pyxel.rect(left, top, self.w, self.h, 3)   # body
        pyxel.rect(left, top, self.w, 2, 11)       # lit top edge
        pyxel.rectb(left, top, self.w, self.h, 11)  # outline


class Crusher:
    """Ceiling-hung crushing block driven by a timer-based state machine.

    States cycle "wait" -> "warn" -> "drop" -> "hold" -> "rise" -> "wait":
    1. "wait": a random delay.
    2. "warn": 30 frames of flashing warning.
    3. "drop": a fast fall until the block lands on a solid or has fallen
       ``HEIGHT`` px.
    4. "hold": a 50-frame pause.
    5. "rise": a slow climb back to ``ceiling_y``.
    """

    def __init__(self, x, ceiling_y, width=24):
        self.x = x
        self.ceiling_y = ceiling_y  # resting y of the block's top edge
        self.w = width
        self.h = 16  # collision height; the teeth drawn below it are cosmetic
        self.y = float(ceiling_y)
        self.state = "wait"
        self.timer = random.randint(80, 160)  # frames until the first warning
        self.speed_down = 5.0  # px per frame while dropping
        self.speed_up = 1.5    # px per frame while rising
        self.bottom_y = float(ceiling_y)  # y where the last drop stopped (not read in this class)
        self.active_y = ceiling_y + HEIGHT  # lowest y a drop may reach
        self.warning_flash = 0  # frames of warning flash left; drives draw()

    def update(self, solids):
        """Advance the state machine one frame; *solids* are (x, y, w, h) rects."""
        # Only the timed states count down; "drop" and "rise" are distance-based.
        if self.state in ("wait", "warn", "hold"):
            self.timer -= 1
        if self.state == "wait":
            if self.timer <= 0:
                self.state = "warn"
                self.timer = 30
                self.warning_flash = 30
        elif self.state == "warn":
            self.warning_flash -= 1
            if self.timer <= 0:
                self.state = "drop"
                self.timer = 0
        elif self.state == "drop":
            self.y += self.speed_down
            hit_solid = False
            # Snap onto the first solid (in list order) the block now overlaps.
            for s in solids:
                if rects_overlap((self.x, self.y, self.w, self.h), s):
                    self.y = s[1] - self.h
                    hit_solid = True
                    break
            if hit_solid or self.y >= self.active_y:
                self.bottom_y = self.y
                self.state = "hold"
                self.timer = 50
        elif self.state == "hold":
            if self.timer <= 0:
                self.state = "rise"
        elif self.state == "rise":
            self.y -= self.speed_up
            if self.y <= self.ceiling_y:
                self.y = self.ceiling_y
                self.state = "wait"
                self.timer = random.randint(120, 200)

    def rect(self):
        """Collision box of the block body (excludes the drawn teeth)."""
        return (self.x, int(self.y), self.w, self.h)

    def draw(self, cam_x):
        px = int(self.x - cam_x)
        py = int(self.y)
        chain_top = self.ceiling_y
        # Chain links every 6 px from the ceiling down to the block.
        for cy in range(int(chain_top), py, 6):
            pyxel.rect(px + self.w // 2 - 1, cy, 3, 4, 5)
        # Colour 8 while dropping/holding.  Otherwise alternate 9/5 with
        # warning_flash, which is 0 outside "warn" and so gives a steady 9.
        col = 8 if self.state in ("drop", "hold") else (9 if self.warning_flash % 6 < 3 else 5)
        pyxel.rect(px, py, self.w, self.h, col)
        pyxel.rectb(px, py, self.w, self.h, 7)
        # Teeth: 5 px triangles hanging below the collision box.
        for i in range(0, self.w, 6):
            pyxel.tri(px + i, py + self.h, px + i + 3, py + self.h + 5, px + i + 6, py + self.h, 15)
        # Blinking outline during the warning phase.
        if self.warning_flash > 0:
            if self.warning_flash % 4 < 2:
                pyxel.rectb(px - 2, py - 2, self.w + 4, self.h + 10, 9)


class TopSpike:
    """Row of downward-pointing spikes with a 10 px tall hazard box of width ``w``.

    ``draw`` renders one 8 px wide spike for every 8 px of width.
    """

    def __init__(self, x, y, w):
        self.x, self.y, self.w = x, y, w
        self.h = 10

    def rect(self):
        return self.x, self.y, self.w, self.h

    def draw(self, cam_x):
        left = int(self.x - cam_x)
        top = int(self.y)
        for sx in range(left, left + self.w, 8):
            pyxel.rect(sx, top, 8, 4, 5)                                # mounting strip
            pyxel.tri(sx, top + 4, sx + 4, top + 10, sx + 8, top + 4, 8)  # spike
            pyxel.pset(sx + 4, top + 8, 9)                               # tip glint


class LavaPool:
    """Lava strip: a 12 px collision band at the surface, drawn down to the screen bottom."""

    def __init__(self, x, y, w):
        self.x = x
        self.y = y
        self.w = w
        self.h = 12  # only the surface band is used for collision
        self.t = 0   # animation frame counter, advanced by update()

    def update(self):
        self.t += 1

    def rect(self):
        return (self.x, self.y, self.w, self.h)

    def draw(self, cam_x):
        left = int(self.x - cam_x)
        top = int(self.y)
        right = left + self.w
        # Skip pools that are entirely off-screen horizontally.
        if not (right > 0 and left <= WIDTH):
            return
        # The lava body always reaches the bottom edge of the screen, whatever
        # the collision height, so the pool never looks like a floating strip.
        pyxel.rect(left, top, self.w, HEIGHT - top, 8)
        pyxel.rect(left, top, self.w, 6, 9)  # bright surface layer
        # Rolling surface waves.
        for off in range(0, self.w, 4):
            wave = int(math.sin((off + self.t * 2) * 0.4) * 2)
            pyxel.rect(left + off, top + wave, 4, 3, 9)
        # Drifting bubble sparkles along the surface.
        drift = self.t * 2 % 8
        for off in range(0, self.w, 8):
            bx = left + off + drift
            if 0 <= bx < right:
                pyxel.pset(bx, top + 1, 10)


class Particle:
    """Debris pixel thrown in a random direction, pulled down by gravity, drawn while alive."""

    def __init__(self, x, y, col):
        self.x, self.y = float(x), float(y)
        heading = random.uniform(0, math.tau)
        launch_speed = random.uniform(1, 3.5)
        self.vx = math.cos(heading) * launch_speed
        self.vy = math.sin(heading) * launch_speed - 1.5  # biased upwards
        self.col = col
        self.life = random.randint(14, 28)  # frames left before it stops drawing

    def update(self):
        self.x += self.vx
        self.vy += 0.18  # gravity
        self.y += self.vy
        self.life -= 1

    def draw(self, cam_x):
        if self.life <= 0:
            return
        pyxel.pset(int(self.x - cam_x), int(self.y), self.col)


class Bullet:
    """Player projectile flying in a straight line.

    It dies when it overlaps a solid or has travelled further than
    ``max_range``.  Travel is measured as Manhattan distance (|dx| + |dy|) from
    the spawn point.  On a wall hit the bullet is moved back to where it was
    before that step.
    """

    def __init__(self, x, y, vx, vy, col, dmg, max_range=180):
        self.x = self.start_x = float(x)
        self.y = self.start_y = float(y)
        self.vx, self.vy = vx, vy
        self.col = col
        self.dmg = dmg
        self.alive = True
        self.max_range = max_range

    def update(self, solids):
        before = (self.x, self.y)
        self.x += self.vx
        self.y += self.vy
        travelled = abs(self.x - self.start_x) + abs(self.y - self.start_y)
        if travelled > self.max_range:
            self.alive = False
            return
        hitbox = (self.x - 2, self.y - 2, 4, 4)
        if any(rects_overlap(hitbox, solid) for solid in solids):
            self.alive = False
            self.x, self.y = before

    def draw(self, cam_x):
        cx = int(self.x - cam_x)
        cy = int(self.y)
        tip = 2 if self.vx >= 0 else -2  # bright pixel on the leading end
        pyxel.rect(cx - 2, cy - 1, 5, 3, self.col)
        pyxel.pset(cx + tip, cy, 15)


class Rocket:
    """Horizontal rocket that explodes on a wall or after ``max_range`` px.

    While exploded it stays alive for 16 updates.  During that time
    ``explosion_rect`` returns a 44x44 damage box centred on the blast point.
    """

    def __init__(self, x, y, direction, max_range=170, col=8, dmg=8):
        self.x = self.start_x = self.ex_x = float(x)
        self.y = self.ex_y = float(y)
        self.direction = direction
        self.vx = direction * 5
        self.vy = 0
        self.col = col
        self.dmg = dmg
        self.max_range = max_range
        self.alive = True
        self.exploded = False
        self.timer = 0

    def update(self, solids):
        if self.exploded:
            self.timer -= 1
            if self.timer <= 0:
                self.alive = False
            return
        self.x += self.vx
        self.y += self.vy
        out_of_range = abs(self.x - self.start_x) > self.max_range
        if out_of_range or any(
            rects_overlap((self.x - 4, self.y - 2, 8, 4), solid) for solid in solids
        ):
            self.explode()

    def explode(self):
        if self.exploded:
            return
        self.exploded = True
        self.ex_x, self.ex_y = self.x, self.y
        self.timer = 16

    def explosion_rect(self):
        if self.exploded:
            return (self.ex_x - 22, self.ex_y - 22, 44, 44)
        return None

    def draw(self, cam_x):
        if self.exploded:
            # Growing fireball: three nested circles.
            radius = 18 - self.timer
            ex = int(self.ex_x - cam_x)
            ey = int(self.ex_y)
            pyxel.circ(ex, ey, radius, 10)
            pyxel.circ(ex, ey, max(1, radius - 4), 9)
            pyxel.circ(ex, ey, max(1, radius - 8), 8)
            return
        cx = int(self.x - cam_x)
        cy = int(self.y)
        pyxel.rect(cx - 4, cy - 2, 8, 4, self.col)
        # Jittering exhaust trail behind the rocket.
        for i in range(3):
            pyxel.pset(cx - self.direction * (5 + i * 2), cy + random.randint(-1, 1), 10 if i == 0 else 9)


class EnemyBullet:
    """Horizontal enemy shot that dies on a solid or after 75 updates."""

    def __init__(self, x, y, vx):
        self.x, self.y = float(x), float(y)
        self.vx = vx
        self.life = 75
        self.alive = True

    def update(self, solids):
        self.x += self.vx
        self.life -= 1
        hitbox = (self.x - 2, self.y - 2, 4, 4)
        if any(rects_overlap(hitbox, solid) for solid in solids) or self.life <= 0:
            self.alive = False

    def draw(self, cam_x):
        colour = 8 if self.life % 8 < 4 else 9  # flicker every 4 frames
        pyxel.circ(int(self.x - cam_x), int(self.y), 2, colour)


class BlackHole:
    """Gravity-gun projectile that travels while growing, then lingers and pulls foes in.

    While ``radius`` is below ``max_radius`` it moves horizontally (unless a
    solid blocks the step) and grows by 2 px per update.  After that it stays
    put and counts ``life`` down.  On every update it drags live enemies within
    200 px, and the boss within 150 px, towards its centre.  At regular ticks it
    damages anything near the rim.
    """

    def __init__(self, x, y, direction, bonus_damage=0):
        self.x, self.y = float(x), float(y)
        self.vx = direction * 4
        self.radius = 0
        self.max_radius = 40
        self.life = 120
        self.alive = True
        self.anim = 0
        self.damage_tick = 0
        self.bonus_damage = bonus_damage

    def _offset_to(self, body):
        """Return (dx, dy, dist) from *body*'s centre to the hole; dist is at least 1."""
        dx = self.x - (body.x + body.w / 2)
        dy = self.y - (body.y + body.h / 2)
        return dx, dy, max(1, math.sqrt(dx * dx + dy * dy))

    def _pull_enemy(self, e, dmg):
        """Drag enemy *e* inwards and damage it every 8th tick when near the rim."""
        suction_range = 200
        dx, dy, dist = self._offset_to(e)
        if dist >= suction_range:
            return
        pull = 2.8 if dist > self.radius else 5.5
        e.vx += (dx / dist) * pull
        e.vy += (dy / dist) * 0.9
        if dist < self.radius * 0.65:
            e.vx *= 0.78  # damp horizontal speed deep inside the hole
        if self.damage_tick % 8 == 0 and dist < self.radius + 18:
            e.hp -= dmg
            e.hit_flash = 4
            if e.hp <= 0:
                e.alive = False

    def _pull_boss(self, boss, dmg):
        """Gently drag the boss inwards and hit it every 10th tick when near the rim."""
        dx, dy, dist = self._offset_to(boss)
        if dist >= 150:
            return
        boss.vx += (dx / dist) * 0.45
        boss.vy += (dy / dist) * 0.08
        if self.damage_tick % 10 == 0 and dist < self.radius + 20:
            boss.take_hit(dmg)

    def update(self, enemies, boss, solids):
        if not self.alive:
            return
        self.anim += 1
        self.damage_tick += 1

        if self.radius >= self.max_radius:
            self.life -= 1
            if self.life <= 0:
                self.alive = False
                return
        else:
            next_x = self.x + self.vx
            probe = (next_x - 6, self.y - 6, 12, 12)
            if not any(rects_overlap(probe, solid) for solid in solids):
                self.x = next_x
            self.radius += 2

        dmg = 2 + self.bonus_damage
        for e in enemies:
            if e.alive:
                self._pull_enemy(e, dmg)
        if boss and boss.alive:
            self._pull_boss(boss, dmg)

    def draw(self, cam_x):
        cx = int(self.x - cam_x)
        cy = int(self.y)
        r = int(self.radius)
        pyxel.circb(cx, cy, r + 2, 13)
        pyxel.circb(cx, cy, r + 5, 2)
        pyxel.circ(cx, cy, max(1, r - 3), 0)
        pyxel.circ(cx, cy, max(1, r - 7), 1)
        orbit = max(1, r - 2)
        palette = [14, 13, 8]
        for i in range(6):
            ang = self.anim * 0.13 + i * 1.04
            pyxel.pset(cx + int(math.cos(ang) * orbit), cy + int(math.sin(ang) * orbit), palette[i % 3])


class Body:
    """Shared movement for objects with a box (``x, y, w, h``) and velocity (``vx, vy``).

    ``move_and_collide`` moves along x and resolves overlaps, then moves along
    y and resolves again.  Resolving each axis separately keeps wall hits from
    being mistaken for floor hits.  ``on_ground`` ends up True only when a
    downward move was stopped by a solid.
    """

    def move_and_collide(self, solids):
        self.on_ground = False

        self.x += self.vx
        for block in solids:
            if rects_overlap((self.x, self.y, self.w, self.h), block):
                self._push_out_x(block)

        self.y += self.vy
        for block in solids:
            if rects_overlap((self.x, self.y, self.w, self.h), block):
                self._push_out_y(block)

    def _push_out_x(self, block):
        """Snap to the side of *block* we were moving toward and stop horizontally."""
        if self.vx > 0:
            self.x = block[0] - self.w
        elif self.vx < 0:
            self.x = block[0] + block[2]
        self.vx = 0

    def _push_out_y(self, block):
        """Land on top of / bump under *block* and stop vertically."""
        if self.vy > 0:
            self.y = block[1] - self.h
            self.on_ground = True
        elif self.vy < 0:
            self.y = block[1] + block[3]
        self.vy = 0


class GunPickup:
    """Floating weapon crate for ``GUN_TYPES[gun_idx]``; hidden once ``collected``."""

    def __init__(self, x, y, gun_idx):
        self.x, self.y = float(x), float(y)
        self.gun_idx = gun_idx
        self.collected = False
        self.t = random.random() * math.tau  # random bob phase
        self.bob_anim = 0

    def update(self):
        self.t += 0.06
        self.bob_anim += 1

    def draw(self, cam_x):
        if self.collected:
            return
        cx = int(self.x - cam_x)
        cy = int(self.y + math.sin(self.t) * 3)
        gun = GUN_TYPES[self.gun_idx]
        label, colour = gun["name"], gun["col"]
        pyxel.rect(cx - 9, cy - 9, 18, 18, 0)
        pyxel.rectb(cx - 9, cy - 9, 18, 18, colour)
        if self.bob_anim % 30 < 15:
            pyxel.rectb(cx - 10, cy - 10, 20, 20, 13)  # blinking outer frame
        draw_gun_graphic(label, cx - 7, cy - 2, colour)
        pyxel.text(cx - len(label) * 2, cy + 10, label, colour)


class GravityPickup:
    """Floating gravity-gun pickup; hidden once ``collected``."""

    def __init__(self, x, y):
        self.x, self.y = float(x), float(y)
        self.collected = False
        self.t = random.random() * math.tau  # random bob phase
        self.anim = 0

    def update(self):
        self.t += 0.08
        self.anim += 1

    def draw(self, cam_x):
        if self.collected:
            return
        cx = int(self.x - cam_x)
        cy = int(self.y + math.sin(self.t) * 2)
        pyxel.circb(cx, cy, 9, 13 if self.anim % 20 < 10 else 12)
        pyxel.circ(cx, cy, 6, 0)
        pyxel.circb(cx, cy, 5, 13)
        pyxel.circ(cx, cy, 3, 1)
        # Eight twinkling sparks on a ring of radius 11.
        spark_col = 7 if self.anim % 10 < 5 else 10
        for step in range(8):
            rad = math.radians(step * 45)
            pyxel.pset(cx + int(math.cos(rad) * 11), cy + int(math.sin(rad) * 11), spark_col)
        pyxel.text(cx - 8, cy + 12, "GRAV", 13)


class Player(Body):
    """The playable kitty: keyboard-driven movement, jumping, weapons and health.

    Collision comes from ``Body.move_and_collide``; input is read directly from
    Pyxel inside ``update`` and ``shoot``/``fire_black_hole`` return new
    projectiles for the caller to manage.
    """

    def __init__(self, x, y, skin_idx=0, unlocked_guns=None, gun_ammo=None,
                 gun_idx=0, upgrades=None, has_gravity_gun=False):
        """Create a player at world position (x, y).

        Args:
            skin_idx: index into ``KITTY_TRAITS`` and the pink/dark/red sprites.
            unlocked_guns: set of ``GUN_TYPES`` indices the player may use;
                defaults to ``{0}``.  Stored as given, not copied.
            gun_ammo: per-gun ammo list parallel to ``GUN_TYPES`` (-1 means
                unlimited); defaults to each gun's starting ammo.  Stored as
                given, not copied.
            gun_idx: index of the equipped gun.
            upgrades: dict with "strength", "speed" and "jump" levels (all three
                keys are read below).  A falsy value gets an all-zero dict.
            has_gravity_gun: whether ``fire_black_hole`` may be used.
        """
        self.x = float(x)
        self.y = float(y)
        self.w = PLAYER_W
        self.h = PLAYER_H
        self.vx = 0
        self.vy = 0
        self.on_ground = False

        self.skin_idx = skin_idx
        self.dir = 1  # facing: 1 = right, -1 = left
        self.jumps_left = 1  # refilled to 2 on every landing
        self.fly_mode = False  # free flight: no gravity, no solid collision

        self.hp = 100
        self.max_hp = 100
        self.invuln = 0  # frames of damage immunity remaining
        self.dead = False
        self.death_timer = 0  # set to 60 on death, then counts down in update()

        self.shoot_timer = 0  # frames until the next shot is allowed
        self.grav_timer = 0  # frames until the gravity gun is ready again
        self.double_jump_flash = 0  # frames left to draw the air-jump flash
        self.jump_puff = []  # Particles from air jumps; not updated/drawn in this class

        self.unlocked_guns = unlocked_guns if unlocked_guns is not None else {0}
        self.gun_idx = gun_idx
        self.gun_ammo = gun_ammo if gun_ammo is not None else [g["ammo"] for g in GUN_TYPES]

        self.gun_switch_timer = 0  # set to 10 on a gun switch; not read in this class
        self.anim_tick = 0
        self.walk_frame = 0

        self.upgrades = upgrades or {"strength": 0, "speed": 0, "jump": 0}
        self.has_gravity_gun = has_gravity_gun

        # Effective stats = kitty trait + purchased upgrades.
        self.trait = KITTY_TRAITS[self.skin_idx]
        self.base_speed = 2.8 + self.trait["speed"] + 0.22 * self.upgrades["speed"]
        self.jump_bonus = self.trait["jump"] + self.upgrades["jump"]
        self.strength_bonus = self.trait["strength"] + self.upgrades["strength"]
        self.grav_cooldown_bonus = self.trait["grav_cooldown_bonus"]

    def unlocked_list(self):
        """Unlocked gun indices in ascending order (the Q/E cycling order)."""
        return sorted(self.unlocked_guns)

    def current_gun(self):
        """The ``GUN_TYPES`` entry for the equipped gun."""
        return GUN_TYPES[self.gun_idx]

    def current_grav_cooldown(self):
        """Gravity-gun cooldown in frames: 120 minus bonuses, never below 35."""
        return max(35, 120 - self.grav_cooldown_bonus - self.upgrades["strength"] * 5)

    def update(self, solids, world_width, moving_platforms=None):
        """Advance one frame: read input, move, collide and tick timers.

        Args:
            solids: (x, y, w, h) rects to collide with.
            world_width: x is clamped to ``[0, world_width - w]``.
            moving_platforms: optional objects with ``rect()``, ``x`` and
                ``prev_x`` (e.g. MovingPlatform).  They act as extra solids and
                carry a player standing on them horizontally.
        """
        if self.dead:
            self.death_timer -= 1
            return

        move_speed = self.base_speed if not self.fly_mode else self.base_speed + 0.7
        moving = False

        # Horizontal input sets velocity directly (no acceleration).
        if pyxel.btn(pyxel.KEY_A) or pyxel.btn(pyxel.KEY_LEFT):
            self.vx = -move_speed
            self.dir = -1
            moving = True
        elif pyxel.btn(pyxel.KEY_D) or pyxel.btn(pyxel.KEY_RIGHT):
            self.vx = move_speed
            self.dir = 1
            moving = True
        else:
            self.vx = 0

        # Two-frame walk cycle, toggled every 8 frames while walking on ground.
        if moving and self.on_ground:
            self.anim_tick += 1
            if self.anim_tick >= 8:
                self.anim_tick = 0
                self.walk_frame = 1 - self.walk_frame
        else:
            self.walk_frame = 0
            self.anim_tick = 0

        if self.gun_switch_timer > 0:
            self.gun_switch_timer -= 1

        # Q / E cycle backwards / forwards through the unlocked guns.
        unlocked = self.unlocked_list()
        if len(unlocked) > 1:
            if pyxel.btnp(pyxel.KEY_Q):
                idx_in_list = unlocked.index(self.gun_idx) if self.gun_idx in unlocked else 0
                self.gun_idx = unlocked[(idx_in_list - 1) % len(unlocked)]
                self.gun_switch_timer = 10
            if pyxel.btnp(pyxel.KEY_E):
                idx_in_list = unlocked.index(self.gun_idx) if self.gun_idx in unlocked else 0
                self.gun_idx = unlocked[(idx_in_list + 1) % len(unlocked)]
                self.gun_switch_timer = 10

        jump_pressed = pyxel.btnp(pyxel.KEY_W) or pyxel.btnp(pyxel.KEY_UP) or pyxel.btnp(pyxel.KEY_SPACE)

        if self.fly_mode:
            # Free flight: held up/down sets vertical speed; no collision, only
            # clamped to the world and screen bounds.
            if pyxel.btn(pyxel.KEY_W) or pyxel.btn(pyxel.KEY_UP) or pyxel.btn(pyxel.KEY_SPACE):
                self.vy = -3.0
            elif pyxel.btn(pyxel.KEY_S) or pyxel.btn(pyxel.KEY_DOWN):
                self.vy = 3.0
            else:
                self.vy = 0
            self.x += self.vx
            self.y += self.vy
            self.x = clamp(self.x, 0, world_width - self.w)
            self.y = clamp(self.y, 0, HEIGHT - self.h)
            self.on_ground = False
        else:
            # Ground jump, or a weaker air jump with a flash and a particle puff.
            # NOTE(review): jumps_left is only refilled on landing, so walking off
            # a ledge keeps both jumps and allows two air jumps, versus one after
            # a ground jump.  Confirm whether that is intended.
            if jump_pressed and self.jumps_left > 0:
                ground_jump = -6.2 - 0.35 * self.jump_bonus
                air_jump = -4.8 - 0.25 * self.jump_bonus
                if self.on_ground:
                    self.vy = ground_jump
                else:
                    self.vy = air_jump
                    self.double_jump_flash = 6
                    for _ in range(5):
                        self.jump_puff.append(Particle(self.x + self.w / 2, self.y + self.h, 7))
                self.jumps_left -= 1

            self.vy += GRAVITY
            if self.vy > MAX_FALL:
                self.vy = MAX_FALL

            # Moving platforms become solids for this frame.  Remember how far the
            # platform under our feet moved horizontally so we can ride along.
            all_solids = list(solids)
            plat_delta_x = 0
            if moving_platforms:
                for mp in moving_platforms:
                    pr = mp.rect()
                    all_solids.append(pr)
                    stand_rect = (self.x, self.y + self.h, self.w, 2)
                    if rects_overlap(stand_rect, pr):
                        plat_delta_x = mp.x - mp.prev_x

            self.move_and_collide(all_solids)
            # NOTE(review): the carry offset is applied after collision without
            # re-checking solids, so a platform can push the player into a wall;
            # only the world clamp below limits it.
            self.x += plat_delta_x
            self.x = clamp(self.x, 0, world_width - self.w)

            if self.on_ground:
                self.jumps_left = 2

            # Falling 80 px below the screen is fatal.
            if self.y > HEIGHT + 80:
                self.hp = 0

        if self.invuln > 0:
            self.invuln -= 1
        if self.shoot_timer > 0:
            self.shoot_timer -= 1
        if self.grav_timer > 0:
            self.grav_timer -= 1
        if self.double_jump_flash > 0:
            self.double_jump_flash -= 1

        if self.hp <= 0 and not self.dead:
            self.dead = True
            self.death_timer = 60

    def shoot(self):
        """Fire the equipped gun if possible and return the new projectiles.

        Returns an empty list while dead, during the fire-rate cooldown, if the
        gun isn't unlocked, or if it is out of ammo (ammo of -1 is unlimited).
        Otherwise it returns ``gun["count"]`` Bullets (or one Rocket for "RO"),
        spread symmetrically by ``gun["spread"]`` radians, with damage
        ``gun["dmg"] + strength_bonus``.
        """
        if self.dead or self.shoot_timer > 0:
            return []
        if self.gun_idx not in self.unlocked_guns:
            return []

        gun = self.current_gun()
        ammo = self.gun_ammo[self.gun_idx]
        if ammo == 0:
            return []

        if ammo > 0:
            self.gun_ammo[self.gun_idx] -= 1

        self.shoot_timer = gun["rate"]
        snd = SND_SHOTGUN if gun["name"] == "SH" else SND_GUN
        pyxel.play(0, snd)
        bullets = []
        cx = self.x + self.w / 2
        cy = self.y + 9
        dmg_bonus = self.strength_bonus
        for i in range(gun["count"]):
            spread = gun["spread"]
            # Pellet angles are centred on 0: e.g. count 5 -> -2, -1, 0, 1, 2 * spread.
            angle = (i - (gun["count"] - 1) / 2) * spread
            direction = self.dir
            if gun["name"] == "RO":
                bullets.append(Rocket(cx, cy, direction, max_range=gun["range"],
                                      col=gun["col"], dmg=gun["dmg"] + dmg_bonus))
            else:
                speed = 6
                vx = direction * speed * math.cos(angle)
                vy = math.sin(angle) * 2.2
                bullets.append(Bullet(cx, cy, vx, vy, gun["col"],
                                      gun["dmg"] + dmg_bonus, max_range=gun["range"]))
        return bullets

    def fire_black_hole(self):
        """Launch a BlackHole in the facing direction, or return None if unavailable/cooling down."""
        if (not self.has_gravity_gun) or self.dead or self.grav_timer > 0:
            return None
        self.grav_timer = self.current_grav_cooldown()
        pyxel.play(1, SND_GRAVITY)
        return BlackHole(self.x + self.w / 2, self.y + self.h / 2, self.dir,
                         bonus_damage=self.strength_bonus)

    def take_damage(self, amount):
        """Lose *amount* hp, gain 70 frames of invulnerability and a small upward knock.

        Ignored while dead or already invulnerable.
        """
        if self.dead or self.invuln > 0:
            return
        self.hp -= amount
        self.invuln = 70
        self.vy = -2.4
        pyxel.play(1, SND_PLAYER_HURT)

    def draw(self, cam_x):
        """Draw the kitty sprite, remaining-jump indicator and fly-mode label."""
        if self.dead:
            return

        px = int(self.x - cam_x)
        py = int(self.y)
        # Blink while invulnerable; solid flash right after an air jump.
        flash = (self.invuln > 0 and self.invuln % 4 < 2) or self.double_jump_flash > 0

        skin_names = ["pink", "dark", "red"]
        sprite_name = skin_names[self.skin_idx]
        _, _, sw, sh, _ = SPRITE_DEST[sprite_name]

        # Centre the (larger) sprite on the hitbox horizontally and put its bottom
        # edge 2 px above the hitbox bottom so the feet don't overlap the ground.
        draw_x = px - (sw - self.w) // 2
        draw_y = py + self.h - sh - 2

        blt_sprite(sprite_name, draw_x, draw_y,
                   flip=(self.dir == -1), flash=flash)

        # Remaining-jump dots above the head: two for 2 jumps, one for 1.
        dot_y = py - 3
        dot_x = px + self.w // 2
        if self.jumps_left >= 2:
            pyxel.pset(dot_x - 2, dot_y, 11)
            pyxel.pset(dot_x + 2, dot_y, 11)
        elif self.jumps_left == 1:
            pyxel.pset(dot_x, dot_y, 10)

        if self.fly_mode:
            pyxel.text(px - 2, py - 10, "FLY", 11)


class Enemy(Body):
    """Ground enemy whose behaviour is selected by ``kind`` (0-3).

    * kind 0 - patrols; once aggroed it chases the player, speeds up within
      50 px and hops when the player is above it.
    * kind 1 - holds position facing the player, backs off when closer than
      12 px, and performs a long-reach melee attack (``sword_attack_rect``)
      when the player is in range and its 42-update cooldown has elapsed.
    * kind 2 - gunner: retreats while aggroed and the player is near, and
      periodically fires ``EnemyBullet``s (a second one when aggroed and close).
    * kind 3 - rushes the player when aggroed and does a short melee attack
      on a 22-update cooldown.

    ``x``/``y`` are world-space floats for the hitbox's top-left corner;
    ``w``/``h`` are the hitbox size.
    """

    _SPRITE_MAP = {0: "enemy1", 1: "enemy2", 2: "enemy3", 3: "enemy1"}
    # FIX: per-kind bottom crop to remove black artifact rows from sprites
    _SPRITE_CROP = {"enemy1": 4, "enemy2": 4, "enemy3": 2}
    # Horizontal aggro radius per kind (the vertical radius is fixed at 90px).
    # Kept at class level so it is not rebuilt for every enemy every frame.
    _DETECT_RANGE = {0: 150, 1: 180, 2: 200, 3: 130}
    # The health bar turns red once hp is at or below this fraction of max_hp.
    # Relative, so low-max_hp enemies don't show a red bar at full health.
    _LOW_HP_FRACTION = 0.5

    def __init__(self, x, y, hp, speed, kind):
        """Create an enemy at world position (x, y) with ``hp`` and ``speed``.

        ``kind`` must be 0-3: it indexes the tint list and the per-kind
        tables.  Shot/attack timers and facing are randomised, so enemies
        spawned together start out of phase.
        """
        self.x = float(x)
        self.y = float(y)
        # FIX: enlarged hitbox to better match sprite dimensions
        self.w = 22
        self.h = 28
        self.vx = 0
        self.vy = 0
        self.on_ground = False

        self.hp = hp
        self.max_hp = hp
        self.speed = speed
        self.kind = kind
        self.alive = True
        self.hit_flash = 0
        self.shoot_timer = random.randint(20, 60)
        self.attack_cooldown = random.randint(0, 25)
        self.dir = random.choice([-1, 1])
        self.tint = [9, 8, 12, 10][kind]

        self.anim_tick = 0
        self.walk_frame = 0
        self.is_attacking = False

        self.aggro = False
        self.aggro_timer = 0

        # Keep the enemy away from the extreme left edge of the world.
        self.x = clamp(self.x, 20, 99999)

    def can_move_direction(self, solids, spikes, direction):
        """Return True if one step in ``direction`` (-1/0/1) is safe.

        A step is unsafe if it would overlap a solid or, while grounded,
        ``is_danger_ahead`` reports a hazard 8px ahead.
        """
        if direction == 0:
            return True
        future_rect = (self.x + direction * max(1, self.speed), self.y, self.w, self.h)
        for s in solids:
            if rects_overlap(future_rect, s):
                return False
        if self.on_ground and is_danger_ahead(self, solids, spikes, direction, 8):
            return False
        return True

    def update(self, solids, spikes, player, enemy_bullets):
        """Advance one update: gravity, aggro, per-kind AI, animation, movement.

        Kind 2 may append ``EnemyBullet``s to ``enemy_bullets``.  The enemy is
        marked dead if it falls more than 120px below the screen.
        """
        if not self.alive:
            return

        self.is_attacking = False
        self.attack_cooldown += 1

        self.vy += GRAVITY
        if self.vy > MAX_FALL:
            self.vy = MAX_FALL

        # Centre-to-centre offset to the player (positive dx = player right,
        # negative dy = player above).
        dx_player = (player.x + player.w / 2) - (self.x + self.w / 2)
        dy_player = (player.y + player.h / 2) - (self.y + self.h / 2)
        abs_dx = abs(dx_player)
        abs_dy = abs(dy_player)

        # Aggro latches while the player is in range, then decays over 160
        # updates once they leave it.
        detect_range = self._DETECT_RANGE[self.kind]
        if abs_dx < detect_range and abs_dy < 90:
            self.aggro = True
            self.aggro_timer = 160
        elif self.aggro_timer > 0:
            self.aggro_timer -= 1
            if self.aggro_timer <= 0:
                self.aggro = False

        desired_vx = 0

        if self.kind == 0:
            if self.aggro:
                self.dir = 1 if dx_player > 0 else -1
                if self.can_move_direction(solids, spikes, self.dir):
                    mult = 2.2 if abs_dx < 50 else 1.35
                    desired_vx = self.speed * mult * self.dir
                    if self.on_ground and abs_dx < 80 and abs_dy > 20 and dy_player < 0:
                        self.vy = -5.0
                else:
                    self.dir *= -1
            else:
                if not self.can_move_direction(solids, spikes, self.dir):
                    self.dir *= -1
                desired_vx = self.speed * self.dir

        elif self.kind == 1:
            self.dir = 1 if dx_player >= 0 else -1
            desired_vx = 0
            if abs_dx < 12:
                # Back away when the player is practically on top of us.
                desired_vx = -self.speed * 0.4 * self.dir
            if abs_dx < 65 and abs_dy < 28 and self.attack_cooldown >= 42:
                self.is_attacking = True
                self.attack_cooldown = 0

        elif self.kind == 2:
            self.dir = 1 if dx_player >= 0 else -1
            if abs_dx < 80 and self.aggro:
                # Retreat (move opposite to facing) to keep shooting distance.
                if self.can_move_direction(solids, spikes, -self.dir):
                    desired_vx = -self.speed * 0.9 * self.dir
            self.shoot_timer += 1
            shoot_interval = 44 if self.aggro else 75
            if self.shoot_timer >= shoot_interval:
                self.shoot_timer = 0
                if abs_dx < 240 and self.on_ground:
                    vx = 3.8 if dx_player > 0 else -3.8
                    enemy_bullets.append(EnemyBullet(self.x + self.w / 2, self.y + 7, vx))
                    if self.aggro and abs_dx < 180:
                        enemy_bullets.append(EnemyBullet(
                            self.x + self.w / 2, self.y + 3, vx * 0.85))

        elif self.kind == 3:
            self.dir = 1 if dx_player >= 0 else -1
            if self.aggro and abs_dx < 56:
                if self.can_move_direction(solids, spikes, self.dir):
                    desired_vx = self.speed * 1.8 * self.dir
                else:
                    self.dir *= -1
            elif self.aggro and abs_dx < 140:
                if self.can_move_direction(solids, spikes, self.dir):
                    desired_vx = self.speed * 1.2 * self.dir
                else:
                    self.dir *= -1
                if self.on_ground and dy_player < -20:
                    self.vy = -5.8
            if abs_dx < 30 and abs_dy < 24 and self.attack_cooldown >= 22:
                self.is_attacking = True
                self.attack_cooldown = 0

        self.vx = desired_vx

        # Two-frame walk cycle, toggling every 10 updates while moving on ground.
        if abs(self.vx) > 0.01 and self.on_ground:
            self.anim_tick += 1
            if self.anim_tick >= 10:
                self.anim_tick = 0
                self.walk_frame = 1 - self.walk_frame
        else:
            self.walk_frame = 0
            self.anim_tick = 0

        self.move_and_collide(solids)

        if self.hit_flash > 0:
            self.hit_flash -= 1

        if self.y > HEIGHT + 120:
            self.alive = False

    def sword_attack_rect(self):
        """Return the (x, y, w, h) melee hit area while attacking, else None.

        Only kinds 1 (long reach) and 3 (short reach) have melee attacks.
        """
        if self.kind == 3 and self.is_attacking:
            if self.dir == 1:
                return (self.x + self.w, self.y + 2, 18, 14)
            return (self.x - 18, self.y + 2, 18, 14)
        if self.kind == 1 and self.is_attacking:
            reach = 62
            if self.dir == 1:
                return (self.x - 8, self.y - 4, self.w + reach, self.h + 8)
            return (self.x - reach + 8, self.y - 4, self.w + reach, self.h + 8)
        return None

    def contact_damage(self):
        """Damage dealt on body contact; kinds 1 and 3 hit harder while attacking."""
        if self.kind == 3:
            return 18 if not self.is_attacking else 35
        if self.kind == 1:
            return 8 if not self.is_attacking else 38
        if self.kind == 2:
            return 16
        return 18

    def draw(self, cam_x):
        """Draw sprite, aggro marker, weapon visuals and health bar.

        ``cam_x`` is the camera's world x; nothing is drawn once dead.
        """
        if not self.alive:
            return

        px = int(self.x - cam_x)
        py = int(self.y)
        flash = self.hit_flash > 0 and self.hit_flash % 2 == 0

        sprite_name = self._SPRITE_MAP[self.kind]
        u, v, sw, sh, colkey = SPRITE_DEST[sprite_name]

        # FIX: bottom-align sprite with hitbox bottom, centered horizontally
        draw_x = px + (self.w - sw) // 2
        draw_y = py + self.h - sh

        # FIX: crop bottom pixels to eliminate black artifact bar beneath sprite
        crop = self._SPRITE_CROP.get(sprite_name, 3)
        draw_sh = sh - crop

        blt_sprite(sprite_name, draw_x, draw_y,
                   flip=(self.dir == -1), flash=flash, draw_h=draw_sh)

        # FIX: aggro dot positioned relative to sprite top (second dot blinks)
        if self.aggro:
            dot_x = draw_x + sw // 2
            dot_y = draw_y - 4
            pyxel.pset(dot_x, dot_y, 8)
            if pyxel.frame_count % 20 < 10:
                pyxel.pset(dot_x, dot_y - 2, 8)

        # Weapon visuals — adjusted for new hitbox/sprite alignment
        if self.kind == 1:
            sc = 7
            if self.is_attacking:
                if self.dir == 1:
                    pyxel.rect(px - 8, py + self.h // 2 - 2, self.w + 58, 4, 15)
                    pyxel.rect(px + self.w + 46, py + self.h // 2 - 4, 4, 8, 7)
                    pyxel.rectb(px - 8, py + self.h // 2 - 4, self.w + 60, 12, 9)
                else:
                    pyxel.rect(px - 58 + 8, py + self.h // 2 - 2, self.w + 58, 4, 15)
                    pyxel.rect(px - 58 + 4, py + self.h // 2 - 4, 4, 8, 7)
                    pyxel.rectb(px - 58 + 2, py + self.h // 2 - 4, self.w + 60, 12, 9)
            else:
                if self.dir == 1:
                    pyxel.rect(px + self.w, py + self.h // 2 + 3, 18, 3, sc)
                    pyxel.pset(px + self.w + 17, py + self.h // 2 + 4, 15)
                else:
                    pyxel.rect(px - 18, py + self.h // 2 + 3, 18, 3, sc)
                    pyxel.pset(px - 18, py + self.h // 2 + 4, 15)

        if self.kind == 2:
            gun_x = px + self.w if self.dir == 1 else px - 6
            pyxel.rect(gun_x, py + self.h // 2, 6, 2, 7)
            pyxel.pset(gun_x + (5 if self.dir == 1 else 0), py + self.h // 2 + 1, 15)
        elif self.kind == 3:
            if self.is_attacking:
                if self.dir == 1:
                    pyxel.rect(px + self.w, py + self.h // 2 - 2, 9, 2, 15)
                    pyxel.rect(px + self.w - 1, py + self.h // 2 + 1, 10, 1, 10)
                else:
                    pyxel.rect(px - 7, py + self.h // 2 - 2, 9, 2, 15)
                    pyxel.rect(px - 7, py + self.h // 2 + 1, 10, 1, 10)
            else:
                if self.dir == 1:
                    pyxel.rect(px + self.w, py + self.h // 2, 5, 1, 7)
                else:
                    pyxel.rect(px - 3, py + self.h // 2, 5, 1, 7)

        # FIX: HP bar drawn above the sprite (not the hitbox), full sprite width
        bar_x = draw_x
        bar_y = draw_y - 6
        bar_w = sw
        hp_frac = self.hp / self.max_hp if self.max_hp > 0 else 0.0
        # Clamp so overkill (hp < 0) or overheal can't yield a negative or
        # over-wide fill rectangle.
        hp_frac = max(0.0, min(1.0, hp_frac))
        fill = int(hp_frac * bar_w)
        bar_col = 11 if hp_frac > self._LOW_HP_FRACTION else 8
        pyxel.rect(bar_x, bar_y, bar_w, 3, 0)
        pyxel.rect(bar_x, bar_y, fill, 3, bar_col)
        pyxel.rectb(bar_x, bar_y, bar_w, 3, 5)


class Boss(Body):
    """Level boss: walks toward the player, hops, shoots and summons minions.

    The drawn sprite (SPRITE_W x SPRITE_H) is taller than the collision
    hitbox (HITBOX_W x HITBOX_H); ``draw`` bottom-aligns the two.
    """

    SPRITE_W = 24
    SPRITE_H = 64
    HITBOX_W = 24
    HITBOX_H = 48

    def __init__(self, x, y, level_no):
        """Create a boss at world (x, y); max HP scales with ``level_no``."""
        self.x = float(x)
        self.y = float(y)
        self.w = self.HITBOX_W
        self.h = self.HITBOX_H
        self.vx = 0
        self.vy = 0
        self.on_ground = False

        self.level_no = level_no
        self.max_hp = max(1, int((120 + level_no * 12) * 0.7))
        self.hp = self.max_hp
        self.alive = True
        self.hit_flash = 0

        self.shoot_timer = 0
        self.spawn_timer = 0
        self.jump_timer = 0
        self.dir = 1

    def take_hit(self, dmg):
        """Apply ``dmg``, start an 8-update hit flash, and die at 0 HP.

        HP is clamped at 0 and hits on a dead boss are ignored.
        """
        if not self.alive:
            return
        self.hp -= dmg
        self.hit_flash = 8
        if self.hp <= 0:
            self.hp = 0
            self.alive = False

    def boss_ground_ahead(self, solids, moving_right):
        """Return True if floor exists just beyond the leading edge.

        Probes an 8x4 rect 8-16px past the leading edge, 2px below the feet.
        The left-hand probe mirrors the right-hand one.  It used to cover
        0-8px, which let the boss walk 8px closer to ledges when going left.
        """
        probe_x = self.x + self.w + 8 if moving_right else self.x - 16
        probe = (probe_x, self.y + self.h + 2, 8, 4)
        for s in solids:
            if rects_overlap(probe, s):
                return True
        return False

    def update(self, solids, player, boss_bullets, enemies, spawn_enemy_func):
        """Advance one update of boss AI.

        * Walks toward the player, reversing at walls or, when grounded, ledges.
        * Hops every 100+ updates while the player is more than 60px away.
        * Fires an aimed bullet every 43 updates; below 45% HP it adds two more
          bullets with their vertical velocity offset by +/-0.4.
        * Calls ``spawn_enemy_func`` every 281 updates while fewer than 4
          enemies are alive.
        """
        if not self.alive:
            return

        self.vy += GRAVITY
        if self.vy > MAX_FALL:
            self.vy = MAX_FALL

        speed = 1.1 + self.level_no * 0.03
        if player.x + player.w / 2 < self.x + self.w / 2:
            self.vx = -speed
            self.dir = -1
        else:
            self.vx = speed
            self.dir = 1

        front_probe = (self.x + (self.w if self.vx > 0 else -3), self.y + 6, 4, self.h - 10)
        wall_ahead = any(rects_overlap(front_probe, s) for s in solids)
        no_ground_ahead = not self.boss_ground_ahead(solids, self.vx > 0)

        if wall_ahead or (self.on_ground and no_ground_ahead):
            self.vx *= -1
            self.dir = 1 if self.vx >= 0 else -1

        self.jump_timer += 1
        if self.on_ground and self.jump_timer > 100 and abs(player.x - self.x) > 60:
            self.jump_timer = 0
            self.vy = -5.4

        self.move_and_collide(solids)

        self.shoot_timer += 1
        if self.shoot_timer > 42:
            self.shoot_timer = 0
            cx = self.x + self.w / 2
            cy = self.y + 13
            dx = player.x + player.w / 2 - cx
            dy = player.y + player.h / 2 - cy
            dist = max(1, math.sqrt(dx * dx + dy * dy))
            speed_b = 3.0
            boss_bullets.append(Bullet(cx, cy, dx / dist * speed_b, dy / dist * speed_b,
                                       8, 14, max_range=240))
            pyxel.play(1, SND_BOSS_ATTACK)
            if self.hp < self.max_hp * 0.45:
                boss_bullets.append(Bullet(cx, cy, dx / dist * speed_b,
                                           (dy / dist * speed_b) - 0.4, 9, 14, max_range=240))
                boss_bullets.append(Bullet(cx, cy, dx / dist * speed_b,
                                           (dy / dist * speed_b) + 0.4, 9, 14, max_range=240))

        self.spawn_timer += 1
        if self.spawn_timer > 280:
            self.spawn_timer = 0
            if len(enemies) < 4:
                spawn_enemy_func()

        if self.hit_flash > 0:
            self.hit_flash -= 1

    def draw(self, cam_x):
        """Draw the boss sprite and the screen-centred boss HP bar."""
        if not self.alive:
            return
        px = int(self.x - cam_x)
        py = int(self.y)
        flash = self.hit_flash > 0 and self.hit_flash % 2 == 0

        sprite_offset_x = (self.w - self.SPRITE_W) // 2
        sprite_offset_y = self.h - self.SPRITE_H
        blt_sprite("boss", px + sprite_offset_x, py + sprite_offset_y,
                   flip=(self.dir == -1), flash=flash)

        bar_w = 140
        bx = WIDTH // 2 - bar_w // 2
        pyxel.rect(bx, 8, bar_w, 6, 0)
        fill = int((self.hp / self.max_hp) * bar_w)
        pyxel.rect(bx, 8, fill, 6, 8 if self.hp < self.max_hp * 0.35 else 13)
        pyxel.rectb(bx, 8, bar_w, 6, 7)
        pyxel.text(bx + 58, 1, "BOSS", 15)


class Coin:
    """Collectible coin that bobs up and down on a sine wave."""

    def __init__(self, x, y):
        self.x, self.y = x, y
        self.collected = False
        # Random starting phase so neighbouring coins don't bob in unison.
        self.t = math.tau * random.random()

    def update(self):
        """Advance the bob phase."""
        self.t += 0.08

    def draw(self, cam_x):
        """Draw the coin (outer ring colour 10, inner colour 15) unless collected."""
        if not self.collected:
            screen_x = int(self.x - cam_x)
            screen_y = int(self.y + math.sin(self.t) * 2)
            for radius, colour in ((4, 10), (2, 15)):
                pyxel.circ(screen_x, screen_y, radius, colour)


class Checkpoint:
    """Respawn flag: a 20px pole with a pennant (colour 11 once active, else 5)."""

    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.active = False

    def draw(self, cam_x):
        """Draw the pole and its triangular pennant; ``y`` is the pole's base."""
        base_x = int(self.x - cam_x)
        top_y = int(self.y) - 20
        pyxel.rect(base_x, top_y, 3, 20, 7)
        pennant_colour = 11 if self.active else 5
        pyxel.tri(base_x + 3, top_y,
                  base_x + 12, top_y + 5,
                  base_x + 3, top_y + 10,
                  pennant_colour)


class Cannon:
    """Static turret firing an ``EnemyBullet`` every ``fire_rate`` updates.

    ``direction`` is +1 (fires right) or -1 (fires left).  The timer starts
    at a random point in the cycle so cannons don't fire in sync.
    """

    def __init__(self, x, y, direction, fire_rate=90):
        self.x, self.y = float(x), float(y)
        self.direction = direction
        self.fire_rate = fire_rate
        self.timer = random.randint(0, fire_rate)

    def update(self, enemy_bullets):
        """Tick the timer and, when it reaches ``fire_rate``, fire and reset."""
        self.timer += 1
        if self.timer < self.fire_rate:
            return
        self.timer = 0
        enemy_bullets.append(EnemyBullet(self.x, self.y, self.direction * 3.5))

    def draw(self, cam_x):
        """Draw barrel and body; a warning ring shows in the last 12 updates before firing."""
        cx, cy = int(self.x - cam_x), int(self.y)
        barrel_x = cx + 2 if self.direction == 1 else cx - 12
        pyxel.rect(barrel_x, cy - 4, 10, 8, 5)
        pyxel.circ(cx, cy, 6, 4)
        pyxel.circ(cx, cy, 3, 5)
        about_to_fire = self.timer > self.fire_rate - 12
        if about_to_fire:
            pyxel.circb(cx, cy, 7, 9)


# ──────────────────────────────────────────────────────────────
#  LEVEL GENERATION  —  MARIO-STYLE ZONE SYSTEM
# ──────────────────────────────────────────────────────────────

def pick_walkable_segments(solids):
    """Return the solids at least 40px wide whose top y is at least 90.

    Solids are ``(x, y, w, h)`` tuples; input order is preserved.
    """
    def _walkable(rect):
        return rect[2] >= 40 and rect[1] >= 90

    return list(filter(_walkable, solids))


def spawn_on_ground(solids, width_needed=22, margin=10):
    """Pick a random standing position on a walkable ground segment.

    Only segments wide enough to fit ``width_needed`` plus ``margin`` on both
    sides are considered.  Returns ``(x, y)``, where ``y`` is one pixel above
    the chosen segment's top.  Returns ``(40, 220)`` if no segment qualifies.
    """
    usable = [
        (x + margin, y - 1, w - margin * 2)
        for x, y, w, h in pick_walkable_segments(solids)
        if w >= width_needed + margin * 2
    ]
    if not usable:
        return 40, 220

    left, stand_y, span = random.choice(usable)
    hi = int(left + span - width_needed)
    lo = int(left)
    px = random.randint(lo, hi) if hi > lo else lo
    return px, stand_y


def coin_arc(start_x, peak_height, end_x, ground_y, count=5):
    """Return ``count`` integer (x, y) points evenly spaced along a parabola.

    The arc runs from ``start_x`` to ``end_x`` at ``ground_y`` and rises
    ``peak_height`` pixels at its midpoint (smaller y is higher).  A single
    point sits at ``(start_x, ground_y)``.
    """
    denom = max(1, count - 1)
    span = end_x - start_x

    def _point(index):
        t = index / denom
        return (int(start_x + t * span),
                int(ground_y - peak_height * 4 * t * (1 - t)))

    return [_point(index) for index in range(count)]


# ── Zone helpers ────────────────────────────────────────────────

# FIX: spawn enemies 30px above ground_y so feet land exactly on surface
def _add_enemy(wd, x, y, level_no, kind=None):
    hp = 2 + level_no // 2 + (1 if kind in (2, 3) else 0)
    speed = 0.88 + level_no * 0.07 + (0.2 if kind == 1 else 0)
    if kind is None:
        kind = random.choice([0, 0, 1, 2])
    # Clamp x to reasonable world bounds and spawn above ground
    x = max(30, x)
    wd["enemy_specs"].append((x, y - 30, hp, speed, kind))


def build_breather(start_x, ground_y, level_no, wd):
    """Append a 90px flat safe platform with a checkpoint and three coins.

    Takes the same (x, ground_y, level, wd) arguments as the zone builders;
    ``level_no`` is unused.  Returns ``(end_x, ground_y)``.
    """
    width = 90
    wd["solids"].append((start_x, ground_y, width, HEIGHT - ground_y))
    wd["checkpoints"].append((start_x + 12, ground_y))
    wd["coins"].extend((start_x + 18 + 24 * k, ground_y - 20) for k in range(3))
    return start_x + width, ground_y


# Ordered zone line-up per level number.  Levels without an entry use
# _DEFAULT_ZONES.  Tuples, so the shared tables can't be mutated by callers.
_LEVEL_ZONES = {
    1:  ("GRASSLAND", "STAIRCASE", "SPIKE_PIT"),
    2:  ("GRASSLAND", "PIPE_ALLEY", "MOVING_PLATFORMS", "SPIKE_PIT"),
    3:  ("GRASSLAND", "PIPE_ALLEY", "LAVA_RIVER", "ENEMY_ARENA"),
    4:  ("STAIRCASE", "CRUSHER_GAUNTLET", "LAVA_RIVER"),
    5:  ("PIPE_ALLEY", "VERTICAL_CLIMB", "CANNON_ALLEY", "ENEMY_ARENA"),
    6:  ("CRUSHER_GAUNTLET", "CANNON_ALLEY", "LAVA_RIVER", "VERTICAL_CLIMB"),
    7:  ("FORTRESS_WALL", "ISLAND_CHAIN", "CANNON_ALLEY"),
    8:  ("STAIRCASE", "CRUSHER_GAUNTLET", "CANNON_ALLEY", "PENDULUM"),
    9:  ("PIPE_ALLEY", "VERTICAL_CLIMB", "LAVA_RIVER", "CRUSHER_GAUNTLET", "CANNON_ALLEY"),
    10: ("CANNON_ALLEY", "CRUSHER_GAUNTLET", "ENEMY_ARENA"),
}
_DEFAULT_ZONES = ("GRASSLAND", "SPIKE_PIT")


def select_zones(level_no):
    """Return a fresh list of zone names to build, in order, for ``level_no``."""
    return list(_LEVEL_ZONES.get(level_no, _DEFAULT_ZONES))


# ── Individual zone builders ──────────────────────────────────

def zone_grassland(sx, gy, lv, wd):
    """Three flat ground segments separated by two jumpable gaps.

    Segment 1: coins and 1-2 kind-0 enemies.
    Segment 2: a pipe, a floating ledge with a coin, and (level 2+) a
    kind-1 enemy near its right end.
    Segment 3: 1 enemy (2 from level 3) and a coin triangle.

    Gap widths grow with ``lv`` (bounded by ``clamp``).  Appends to the lists
    in ``wd`` and returns ``(end_x, gy)``.  Statement order matters: it fixes
    the sequence of ``random`` calls for a given seed.
    """
    x = sx

    def _pipe(at_x, h):
        # Pipe = 20px-wide shaft of height h plus a 26x10 cap on top, added
        # to both solids and pipe_solids, with a 5-coin arc over it.
        # Returns the cap's top y (unused by the caller below).
        sy = gy - h
        cy2 = sy - 10
        wd["solids"].append((at_x, sy, 20, h))
        wd["solids"].append((at_x - 3, cy2, 26, 10))
        wd["pipe_solids"].append((at_x, sy, 20, h))
        wd["pipe_solids"].append((at_x - 3, cy2, 26, 10))
        for j in range(5):
            t2 = j / 4
            wd["coins"].append((int((at_x - 18) + t2 * 60),
                                int(cy2 - 20 + abs(j - 2) * 10)))
        return cy2

    # ── Segment 1 ──
    seg1 = random.randint(100, 140)
    wd["solids"].append((x, gy, seg1, HEIGHT - gy))
    for i in range(3):
        wd["coins"].append((x + 20 + i * 20, gy - 22))
    for i in range(random.randint(1, 2)):
        ex = x + 50 + i * 40
        _add_enemy(wd, ex, gy, lv, kind=0)
    x += seg1

    # ── Gap 1 (a single coin marks the takeoff point) ──
    gap1 = clamp(14 + lv * 2, 14, 28)
    wd["coins"].append((x - 10, gy - 20))
    x += gap1

    # ── Segment 2: pipe near the left end, floating ledge near the right ──
    seg2 = random.randint(100, 130)
    wd["solids"].append((x, gy, seg2, HEIGHT - gy))
    _pipe(x + 22, random.choice([40, 56]))
    fp_y = gy - random.randint(52, 72)
    fp_x = x + seg2 - 60
    wd["solids"].append((fp_x, fp_y, 44, 8))
    wd["coins"].append((fp_x + 22, fp_y - 14))
    if lv >= 2:
        _add_enemy(wd, x + seg2 - 24, gy, lv, kind=1)
    x += seg2

    # ── Gap 2 with a 4-coin arc across it ──
    gap2 = clamp(20 + lv * 3, 20, 44)
    for arc in coin_arc(x - 4, 32, x + gap2 + 4, gy, count=4):
        wd["coins"].append(arc)
    x += gap2

    # ── Segment 3: first enemy always kind 0, a second (lv >= 3) is 0 or 2 ──
    seg3 = random.randint(90, 120)
    wd["solids"].append((x, gy, seg3, HEIGHT - gy))
    n_enemies = 2 if lv >= 3 else 1
    for i in range(n_enemies):
        ex = x + 20 + i * (seg3 // (n_enemies + 1))
        _add_enemy(wd, ex, gy, lv, kind=0 if i == 0 else random.choice([0, 2]))
    for dc in [(-6, -20), (0, -28), (6, -20)]:
        wd["coins"].append((x + seg3 - 16 + dc[0], gy + dc[1]))
    x += seg3

    return x, gy


def zone_staircase(sx, gy, lv, wd):
    """Six-step pyramid: flat run-in, climb, plateau, descent, flat run-out.

    Step tops never rise above y = 70.  From level 2 a kind-1 enemy guards
    the middle climbing step.  Appends to ``wd`` and returns ``(end_x, gy)``.
    """
    step_w, step_h, n_steps = 36, 16, 6
    last_step = n_steps - 1
    solids = wd["solids"]
    coins = wd["coins"]

    solids.append((sx, gy, 40, HEIGHT - gy))
    cursor = sx + 40

    # Climb.
    for step in range(n_steps):
        top = max(70, gy - (step + 1) * step_h)
        solids.append((cursor, top, step_w, HEIGHT - top))
        if step % 2 == 0 or step == last_step:
            coins.append((cursor + step_w // 2 - 4, top - 14))
            if step == last_step:
                coins.append((cursor + step_w // 2 + 4, top - 14))
        if step == n_steps // 2 and lv >= 2:
            _add_enemy(wd, cursor + step_w // 2, top, lv, kind=1)
        cursor += step_w

    # Plateau with a diamond of coins.
    top_y = max(70, gy - n_steps * step_h)
    top_w = 48
    solids.append((cursor, top_y, top_w, HEIGHT - top_y))
    for dx, dy in ((-8, -16), (0, -22), (8, -16), (0, -30)):
        coins.append((cursor + top_w // 2 + dx, top_y + dy))
    cursor += top_w

    # Descent (never below the ground line).
    for step in range(n_steps):
        top = min(gy, top_y + (step + 1) * step_h)
        solids.append((cursor, top, step_w, HEIGHT - top))
        if step % 2 == 0:
            coins.append((cursor + step_w // 2, top - 14))
        cursor += step_w

    solids.append((cursor, gy, 40, HEIGHT - gy))
    return cursor + 40, gy


def zone_pipe_alley(sx, gy, lv, wd):
    """Flat strip dotted with pipes of varying height.

    The pipe count grows with ``lv``: 3 at level 2, one more per level,
    capped at 7 from level 6.  Each pipe gets a coin arc over its cap.
    Every other pipe except the last gets a kind-1 guard on the ground after
    it.  From level 5 the tallest (72px) pipes get a kind-2 gunner standing
    on the cap.  Returns ``(end_x, gy)``.
    """
    pipe_pattern = [40, 72, 40, 56, 72, 40, 56][:3 + min(lv - 2, 4)]
    spacing = 78
    total_w = len(pipe_pattern) * spacing + 60
    wd["solids"].append((sx, gy, total_w, HEIGHT - gy))

    enemy_w = 22  # Enemy hitbox width (see Enemy.__init__)
    x = sx + 24

    for i, ph in enumerate(pipe_pattern):
        shaft_x = x
        cap_x   = shaft_x - 3
        shaft_y = gy - ph
        cap_y   = shaft_y - 10

        wd["solids"].append((shaft_x, shaft_y, 20, ph))
        wd["solids"].append((cap_x, cap_y, 26, 10))
        wd["pipe_solids"].append((shaft_x, shaft_y, 20, ph))
        wd["pipe_solids"].append((cap_x, cap_y, 26, 10))

        arc_left  = shaft_x - 22
        arc_right = shaft_x + 42
        for j in range(5):
            t2 = j / 4
            cx = int(arc_left + t2 * (arc_right - arc_left))
            cy = int(cap_y - 22 + abs(j - 2) * 11)
            wd["coins"].append((cx, cy))

        if i < len(pipe_pattern) - 1 and i % 2 == 0:
            guard_x = shaft_x + spacing // 2
            _add_enemy(wd, guard_x, gy, lv, kind=1)

        if lv >= 5 and ph == 72:
            # Stand the gunner ON the cap.  _add_enemy puts the hitbox top
            # 30px above the surface it is given.  Passing shaft_y (the cap's
            # underside) started the 28px-tall enemy 8px inside the cap
            # solid.  Also centre it on the pipe (centre = shaft_x + 10)
            # instead of hanging it off the cap's right side.
            _add_enemy(wd, shaft_x + 10 - enemy_w // 2, cap_y, lv, kind=2)

        x += spacing

    for i in range(3):
        wd["coins"].append((sx + total_w - 40 + i * 12, gy - 20))

    return sx + total_w, gy


def zone_moving_platforms(sx, gy, lv, wd):
    """Lava gap crossed on 4-5 horizontally sliding platforms.

    Platform speed rises with ``lv`` (capped at 1.0) and the sweep range
    drops from 44 to 30 at level 6.  From level 5 every odd platform also
    moves vertically and a top-spike strip spans the gap.  A landing pad
    follows; returns ``(end_x, gy)``.
    """
    count = random.randint(4, 5)
    plat_w, plat_gap = 48, 60
    gap_w = count * plat_gap + 40
    wd["lava_pools"].append(LavaPool(sx, gy + 4, gap_w))

    speed = min(1.0, 0.5 + lv * 0.04)
    sweep = 30 if lv >= 6 else 44
    heights = (gy - 70, gy - 50, gy - 80, gy - 50, gy - 70)
    n_heights = len(heights)

    for i in range(count):
        px = sx + i * plat_gap + 10
        py = heights[i % n_heights]
        vertical = 1 if lv >= 5 and i % 2 == 1 else 0
        platform = MovingPlatform(float(px), float(py), plat_w,
                                  move_x=1, move_y=vertical,
                                  range_val=sweep, speed=speed)
        wd["moving_platforms"].append(platform)
        wd["coins"].append((px + plat_w // 2, py - 12))
        if i < count - 1:
            next_py = heights[(i + 1) % n_heights]
            wd["coins"].append((px + plat_w + plat_gap // 2,
                                (py + next_py) // 2 - 12))

    if lv >= 5:
        wd["top_spikes"].append(TopSpike(sx + 10, 68, gap_w - 20))

    landing_x = sx + gap_w
    wd["solids"].append((landing_x, gy, 70, HEIGHT - gy))
    return landing_x + 70, gy


def zone_crusher_gauntlet(sx, gy, lv, wd):
    """Corridor under a low ceiling with 3-5 evenly spaced crushers.

    Crusher timers are staggered by 50.  From level 6 a spike strip sits
    under each crusher.  A row of six coins runs through the middle.
    Returns ``(end_x, gy)``.
    """
    zone_w = random.randint(280, 360)
    ceil_y = random.randint(80, 100)
    wd["solids"].extend([(sx, 0, zone_w, ceil_y),
                         (sx, gy, zone_w, HEIGHT - gy)])

    n = random.randint(3, 5)
    spacing = zone_w // (n + 1)
    for i in range(n):
        crusher_x = sx + (i + 1) * spacing
        crusher = Crusher(crusher_x, ceil_y, width=random.randint(20, 26))
        crusher.timer = 80 + i * 50
        wd["crushers"].append(crusher)
        if lv >= 6:
            wd["spikes"].append((crusher_x - 8, gy - 8, 26, 8))

    mid_y = (ceil_y + gy) // 2
    wd["coins"].extend((sx + i * (zone_w // 6) + 20, mid_y) for i in range(6))
    return sx + zone_w, gy


def zone_spike_pit(sx, gy, lv, wd):
    """Flat floor with alternating spike strips and safe landing spots.

    Spike strips widen and safe spots shrink with ``lv``.  From level 3 each
    safe spot has a 55% chance of a floating ledge with a coin above it.
    Returns ``(end_x, gy)``.
    """
    zone_w = random.randint(200, 270)
    wd["solids"].append((sx, gy, zone_w, HEIGHT - gy))

    spike_w = 20 + lv * 2
    safe_w = max(28, 44 - lv * 2)
    limit = sx + zone_w - 16
    cursor = sx + 16

    while cursor + spike_w + safe_w < limit:
        wd["spikes"].append((cursor, gy - 8, spike_w, 8))
        wd["coins"].append((cursor - 8, gy - 20))
        safe_start = cursor + spike_w
        wd["coins"].append((safe_start + safe_w // 2, gy - 20))
        if lv >= 3 and random.random() < 0.55:
            ledge_w = safe_w + 8
            ledge_y = gy - random.randint(52, 76)
            wd["solids"].append((safe_start - 4, ledge_y, ledge_w, 8))
            wd["coins"].append((safe_start + ledge_w // 2 - 4, ledge_y - 14))
        cursor = safe_start + safe_w + 4

    return sx + zone_w, gy


def zone_lava_river(sx, gy, lv, wd):
    """Lava river crossed on 4-7 randomly placed stepping platforms.

    From level 5 each platform has a 50% chance of bobbing vertically
    (MovingPlatform).  From level 7 a crusher hangs over the middle
    platform.  A landing pad follows; returns ``(end_x, gy)``.
    """
    lava_w = random.randint(200, 320)
    wd["lava_pools"].append(LavaPool(sx, gy + 4, lava_w))
    count = random.randint(4, 7)
    slot = lava_w // count
    middle = count // 2

    for i in range(count):
        px = sx + i * slot + random.randint(0, slot // 3)
        pw = random.randint(28, 44)
        py = gy - random.randint(40, 80)
        bobbing = lv >= 5 and random.random() < 0.5
        if bobbing:
            wd["moving_platforms"].append(
                MovingPlatform(float(px), float(py), pw,
                               move_x=0, move_y=1, range_val=18, speed=0.5))
        else:
            wd["solids"].append((px, py, pw, 8))
        if i < count - 1:
            wd["coins"].append((px + pw + slot // 2, py - 14))
        if lv >= 7 and i == middle:
            wd["crushers"].append(
                Crusher(px + pw // 2 - 12, max(20, py - 70), width=24))

    landing_x = sx + lava_w
    wd["solids"].append((landing_x, gy, 60, HEIGHT - gy))
    return landing_x + 60, gy


def zone_enemy_arena(sx, gy, lv, wd):
    """Open floor with a gunner at the far end plus 1-2 melee enemies.

    A kind-3 enemy is added from level 6 and a ceiling spike strip from
    level 5.  Two floating ledges and a plus-shaped coin cluster fill the
    middle.  Returns ``(end_x, gy)``.
    """
    arena_w = random.randint(250, 340)
    third = arena_w // 3
    wd["solids"].append((sx, gy, arena_w, HEIGHT - gy))

    _add_enemy(wd, sx + arena_w - 40, gy, lv, kind=2)
    grunts = random.randint(1, 2)
    for slot in range(1, grunts + 1):
        _add_enemy(wd, sx + slot * third, gy, lv, kind=random.choice([0, 1]))
    if lv >= 6:
        _add_enemy(wd, sx + arena_w // 2, gy, lv, kind=3)

    mid_x, mid_y = sx + arena_w // 2, gy - 24
    for dx, dy in ((0, 0), (-14, 0), (14, 0), (0, -14), (0, 14)):
        wd["coins"].append((mid_x + dx, mid_y + dy))

    for slot in (1, 2):
        wd["solids"].append((sx + slot * third - 24, gy - 60, 48, 8))

    if lv >= 5:
        wd["top_spikes"].append(TopSpike(sx, 70, arena_w))
    return sx + arena_w, gy


def zone_vertical_climb(sx, gy, lv, wd):
    """Narrow shaft climbed via 5-8 ledges alternating left/right.

    The climb leads to a plateau to the right of the shaft, followed by
    three descending steps back to ``gy``.  Hazards scale with level:
    level 5+ adds TopSpikes under odd ledges; level 6+ adds crushers above
    even ledges lower than y = 100; level 7+ adds a kind-2 enemy near the top.
    Ledge heights are clamped at y = 60.  Returns ``(end_x, gy)``.
    Statement order fixes the ``random`` call sequence for a given seed.
    """
    zone_w = random.randint(80, 100)
    n = random.randint(5, 8)
    step = random.randint(30, 40)
    wd["solids"].append((sx, gy, zone_w, HEIGHT - gy))
    # Height of the highest ledge; the plateau is built at this height too.
    top_y = max(60, gy - n * step)
    for i in range(n):
        py = max(60, gy - (i + 1) * step)
        pw = random.randint(30, 40)
        # Even ledges hug the left wall of the shaft, odd ones the right.
        px = (sx + random.randint(4, 18)) if i % 2 == 0 else (sx + zone_w - pw - random.randint(4, 18))
        wd["solids"].append((px, py, pw, 8))
        wd["coins"].append((px + pw // 2, py - 14))
        if lv >= 5 and i % 2 == 1:
            # Placed at py + 8, i.e. the ledge's underside (ledges are 8px
            # tall) - presumably a downward-facing spike strip; confirm
            # against TopSpike.
            wd["top_spikes"].append(TopSpike(px + 4, py + 8, max(8, pw - 8)))
        if lv >= 6 and i % 2 == 0 and py > 100:
            wd["crushers"].append(Crusher(px + pw // 2 - 12, max(20, py - 70), width=24))
    if lv >= 7:
        # NOTE(review): spawns over the middle of the shaft at ledge-top
        # height, not on a specific ledge or the plateau.  Depending on which
        # side the top ledge lands, the enemy may stand on it or fall down
        # the shaft - confirm this is intended.
        _add_enemy(wd, sx + zone_w // 2, top_y, lv, kind=2)
    top_plat_x = sx + zone_w
    wd["solids"].append((top_plat_x, top_y, 80, HEIGHT - top_y))
    desc_steps = 3
    # Split the drop back to gy into desc_steps + 1 equal increments.
    desc_step_h = (gy - top_y) // max(1, desc_steps + 1)
    for s in range(desc_steps):
        sy2 = top_y + (s + 1) * desc_step_h
        wd["solids"].append((top_plat_x + 80 + s * 44, sy2, 44, HEIGHT - sy2))
    end_x = top_plat_x + 80 + desc_steps * 44
    wd["solids"].append((end_x, gy, 40, HEIGHT - gy))
    return end_x + 40, gy


def zone_cannon_alley(sx, gy, lv, wd):
    """Floor with 3-5 pillars, each carrying a cannon; facing alternates.

    The first cannon faces right.  Fire rate speeds up with level (never
    faster than 55).  From level 7 every even pillar also gets a crusher on
    the side its cannon is NOT firing toward.  Returns ``(end_x, gy)``.
    """
    zone_w = random.randint(250, 340)
    fire_rate = max(55, 90 - (lv - 5) * 7)
    wd["solids"].append((sx, gy, zone_w, HEIGHT - gy))
    n = random.randint(3, 5)
    sp = zone_w // n
    direction = 1
    for i in range(n):
        wx = sx + i * sp + sp // 2
        wh = random.randint(40, 80)
        wy = gy - wh
        wd["solids"].append((wx, wy, 10, wh))
        cy2 = wy + random.randint(10, wh // 2)
        fr = max(55, fire_rate + random.randint(-10, 10))
        wd["cannons"].append((wx + 5, cy2, direction, fr))
        if lv >= 7 and i % 2 == 0:
            # Use THIS cannon's direction.  The flip for the next pillar used
            # to happen first, so on even i the value seen here was always -1
            # and the `wx - 30` branch could never run.  Crusher creation
            # stays after the cannon's random draws, so the RNG sequence is
            # unchanged.
            safe_x = wx - 30 if direction == 1 else wx + 20
            wd["crushers"].append(Crusher(max(sx + 4, safe_x), max(20, wy - 60), width=24))
        direction *= -1
    for i in range(n - 1):
        wd["coins"].append((sx + i * sp + sp, gy - 30))
    return sx + zone_w, gy


def zone_fortress_wall(sx, gy, lv, wd):
    """Tall wall with a single crawl-through gap guarded by a left-facing cannon.

    The wall's top is lined with TopSpikes.  There is a 60px run-up before
    the wall and an 80px landing with coins after it.  ``lv`` is unused.
    Returns ``(end_x, gy)``.
    """
    wd["solids"].append((sx, gy, 60, HEIGHT - gy))
    wx = sx + 60
    wall_top = max(30, gy - random.randint(100, 150))
    wall_w = random.randint(22, 30)
    total_wall_h = gy - wall_top
    gap_lo = wall_top + 30
    # When gy < 130 the wall_top clamp pushes this upper bound below gap_lo,
    # and random.randint would raise ValueError.  Pin the gap to the lowest
    # legal spot instead.  For gy >= 130 the bounds (and RNG draw) are
    # unchanged.
    gap_hi = max(gap_lo, wall_top + total_wall_h - 70)
    gap_y = random.randint(gap_lo, gap_hi)
    gap_h = random.randint(28, 36)
    top_h = gap_y - wall_top
    if top_h > 0:
        wd["solids"].append((wx, wall_top, wall_w, top_h))
        wd["top_spikes"].append(TopSpike(wx, wall_top, wall_w))
    bot_y = gap_y + gap_h
    bot_h = gy - bot_y
    if bot_h > 0:
        wd["solids"].append((wx, bot_y, wall_w, bot_h))
    wd["cannons"].append((wx + wall_w - 4, gap_y + gap_h // 2, -1, 80))
    far_x = wx + wall_w
    wd["solids"].append((far_x, gy, 80, HEIGHT - gy))
    for ddx, ddy in [(15, -20), (30, -20), (45, -20), (30, -34)]:
        wd["coins"].append((far_x + ddx, gy + ddy))
    return far_x + 80, gy


def zone_island_chain(sx, gy, lv, wd):
    """Chain of 5-7 small islands over lava, ending at a short bridge.

    Each island holds, with 55% probability, an enemy of kind 0 or 2;
    otherwise it holds a coin.  Returns ``(end_x, gy)``.
    """
    n = random.randint(5, 7)
    sp = random.randint(50, 70)
    lava_w = n * sp + 40
    wd["lava_pools"].append(LavaPool(sx, gy + 4, lava_w))
    for i in range(n):
        ix = sx + i * sp + 20
        iw = random.randint(32, 48)
        iy = gy - random.randint(30, 60)
        wd["solids"].append((ix, iy, iw, 8))
        if random.random() < 0.55:
            # Go through the shared helper.  This block used to duplicate its
            # hp/speed formula with drifted constants (speed 0.9 vs 0.88, and
            # no +1 HP for kind 2).  The kind is still drawn right after the
            # 55% roll, so the RNG sequence is unchanged.
            _add_enemy(wd, ix + iw // 2, iy, lv, kind=random.choice([0, 2]))
        else:
            wd["coins"].append((ix + iw // 2, iy - 16))
    bridge_x = sx + lava_w
    wd["solids"].append((bridge_x, gy, 40, HEIGHT - gy))
    return bridge_x + 40, gy


def zone_pendulum(sx, gy, lv, wd):
    """Fixed-width lava pit crossed on three vertically bobbing platforms.

    Each platform's phase ``t`` is offset by pi/3 from the previous one.
    ``lv`` is unused.  A landing pad follows; returns ``(end_x, gy)``.
    """
    zone_w = 260
    third = zone_w // 3
    wd["lava_pools"].append(LavaPool(sx, gy + 4, zone_w))
    for i in range(3):
        px = sx + i * third + 20
        py = gy - random.randint(60, 90)
        pw = random.randint(40, 56)
        platform = MovingPlatform(float(px), float(py), pw,
                                  move_x=0, move_y=1, range_val=50, speed=0.6)
        platform.t = i * (math.pi / 3)
        wd["moving_platforms"].append(platform)
        wd["coins"].append((px + pw // 2, py - 14))
    landing_x = sx + zone_w
    wd["solids"].append((landing_x, gy, 60, HEIGHT - gy))
    return landing_x + 60, gy


# Zone name (as used by select_zones) -> builder(sx, gy, lv, wd) -> (end_x, gy).
_ZONE_BUILDERS = dict(
    GRASSLAND=zone_grassland,
    STAIRCASE=zone_staircase,
    PIPE_ALLEY=zone_pipe_alley,
    MOVING_PLATFORMS=zone_moving_platforms,
    CRUSHER_GAUNTLET=zone_crusher_gauntlet,
    SPIKE_PIT=zone_spike_pit,
    LAVA_RIVER=zone_lava_river,
    ENEMY_ARENA=zone_enemy_arena,
    VERTICAL_CLIMB=zone_vertical_climb,
    CANNON_ALLEY=zone_cannon_alley,
    FORTRESS_WALL=zone_fortress_wall,
    ISLAND_CHAIN=zone_island_chain,
    PENDULUM=zone_pendulum,
)


def make_procedural_level(level_no):
    """Procedurally generate the layout for level *level_no*.

    Layout, left to right:
      * a flat 200px start buffer with three coins,
      * the zones chosen by ``select_zones`` (each built by its
        ``_ZONE_BUILDERS`` entry and always followed by a breather section),
      * either a boss arena (levels 4, 7 and 10; level 10 gets a wider arena
        with floating side platforms) or a plain 140px end platform.
    The exit portal sits 80px before the right edge of that final platform.

    Args:
        level_no: 1-based level number.

    Returns:
        dict consumed by ``Game.load_level``: geometry/entity spec lists,
        ``world_width``, ``start_ground_y``, portal position, boss flags and
        spawn point, and ``bg_color``.
    """
    wd = {
        "solids": [], "spikes": [], "top_spikes": [], "lava_pools": [],
        "crushers": [], "moving_platforms": [], "cannons": [],
        "coins": [], "enemy_specs": [], "gun_pickups": [],
        "checkpoints": [], "pipe_solids": [], "gravity_pickup": None,
    }

    ground_y = 240
    # Remember the starting ground height so the loader spawns the player on it.
    start_ground_y = ground_y

    # ── Start buffer ─────────────────────────────────────────────
    wd["solids"].append((0, ground_y, 200, HEIGHT - ground_y))
    for i in range(3):
        wd["coins"].append((60 + i * 24, ground_y - 20))

    x = 200
    zones = select_zones(level_no)

    for zone_name in zones:
        builder = _ZONE_BUILDERS.get(zone_name)
        if builder:
            x, ground_y = builder(x, ground_y, level_no, wd)
        x, ground_y = build_breather(x, ground_y, level_no, wd)

    # ── Gun pickup for this level ─────────────────────────────────
    for gi, gun in enumerate(GUN_TYPES):
        if gun["unlock_level"] == level_no:
            wd["gun_pickups"].append((x - 60, ground_y - 16, gi))

    # ── Gravity pickup level 6 ───────────────────────────────────
    if level_no == 6:
        wd["gravity_pickup"] = (x - 40, ground_y - 20)

    # ── Boss or standard end platform ────────────────────────────
    is_boss = level_no in [4, 7, 10]
    boss_spawn = None

    if is_boss:
        arena_w = 480 if level_no == 10 else 300
        end_gy  = 230
        wd["solids"].append((x, end_gy, arena_w, HEIGHT - end_gy))
        if level_no == 10:
            # Two floating platforms on each side of the arena, each with a coin.
            for i, (fw, fy) in enumerate([(60, end_gy - 55), (60, end_gy - 55),
                                           (50, end_gy - 90), (50, end_gy - 90)]):
                side_x = (x + 40) if i % 2 == 0 else (x + arena_w - 40 - fw)
                wd["solids"].append((side_x, fy, fw, 8))
                wd["coins"].append((side_x + fw // 2, fy - 14))
            # Centre the final boss using its OWN hitbox width; the y already
            # used FinalBoss.HITBOX_H, but x used Boss.HITBOX_W.
            boss_spawn = (x + arena_w // 2 - FinalBoss.HITBOX_W // 2,
                          end_gy - FinalBoss.HITBOX_H)
        else:
            boss_spawn = (x + 80, end_gy - Boss.HITBOX_H)
        end_plat_x = x
        x += arena_w
        portal_x = end_plat_x + arena_w - 80
        portal_y = end_gy - 48
    else:
        end_w = 140
        end_gy = ground_y
        wd["solids"].append((x, end_gy, end_w, HEIGHT - end_gy))
        end_plat_x = x
        x += end_w
        portal_x = end_plat_x + end_w - 80
        portal_y = end_gy - 48

    world_width = x + 60
    is_final = (level_no == LEVEL_COUNT_FOR_WIN)
    bg_colors = [0, 1, 0, 1, 5, 0, 1, 5, 0, 1]
    # Clamp on both sides so any level number maps to a valid colour.
    bg_index = max(0, min(level_no - 1, len(bg_colors) - 1))

    return {
        "world_width": world_width,
        "start_ground_y": start_ground_y,   # used by load_level for the player spawn
        "solids": wd["solids"],
        "spikes": wd["spikes"],
        "top_spikes": wd["top_spikes"],
        "lava_pools": wd["lava_pools"],
        "crushers": wd["crushers"],
        "moving_platforms": wd["moving_platforms"],
        "cannons": wd["cannons"],
        "coins": wd["coins"],
        "gun_pickups": wd["gun_pickups"],
        "gravity_pickup": wd["gravity_pickup"],
        "checkpoints": wd["checkpoints"],
        "enemy_specs": wd["enemy_specs"],
        "pipe_solids": wd["pipe_solids"],
        "end_x": portal_x,
        "portal_x": portal_x,
        "portal_y": portal_y,
        "is_final_portal": is_final,
        "boss": is_boss,
        "boss_spawn": boss_spawn,
        "bg_color": bg_colors[bg_index],
    }



def draw_potion(x, y, col, shimmer, label, level_val, cost_str):
    """Draw one shop potion bottle at (x, y) with its label and price.

    The bottle is filled with colour *col*; the top part is blacked out by
    ``16 * (1 - level_val / 5)`` pixels, so higher *level_val* shows more
    liquid. *shimmer* is a frame counter that blinks two highlight pixels.
    """
    body_x, body_y = x + 4, y + 6
    # Bottle body: black backing, then the liquid colour.
    pyxel.rect(body_x, body_y, 14, 18, 0)
    pyxel.rect(body_x + 1, body_y + 1, 12, 16, col)
    empty_h = int(16 * (1 - level_val / 5.0))
    if empty_h > 0:
        pyxel.rect(body_x + 1, body_y + 1, 12, empty_h, 0)

    # Glint visible for the first 10 of every 20 frames.
    if shimmer % 20 < 10:
        pyxel.pset(x + 8, y + 14, 15)
        pyxel.pset(x + 12, y + 11, 15)

    # Neck, lip and cork: (dx, dy, w, h, colour), drawn bottom to top.
    for dx, dy, w, h, c in ((7, 3, 8, 4, 5), (8, 1, 6, 3, 6), (9, 0, 4, 2, 4)):
        pyxel.rect(x + dx, y + dy, w, h, c)

    # White outlines for body and neck.
    pyxel.rectb(body_x, body_y, 14, 18, 7)
    pyxel.rectb(x + 7, y + 3, 8, 4, 7)

    text_x = x + 2
    pyxel.text(text_x, y + 26, label, 15)
    pyxel.text(text_x, y + 34, cost_str, 10)


def draw_loading_screen(t):
    """Render the animated title screen for frame counter *t*.

    Draws a drifting starfield, a pulsing bow-tie emblem with orbiting
    sparkles, the glowing "KITTY SLAYER" title, a blinking subtitle and
    start prompt, and the version tag.
    """
    pyxel.cls(0)

    # Drifting starfield: 40 pixels, speeds and colours cycling by index.
    star_cols = (1, 2, 5)
    for idx in range(40):
        star_x = (idx * 137 + t * (1 + idx % 3)) % WIDTH
        star_y = (idx * 97 + t // 2) % HEIGHT
        pyxel.pset(int(star_x), int(star_y), star_cols[idx % 3])

    mid_x = WIDTH // 2
    mid_y = HEIGHT // 2

    # Emblem size pulses by +/-5%.
    phase = t * 0.04
    pulse = 1.0 + math.sin(phase * 2) * 0.05
    wing_w = int(110 * pulse)
    wing_h = int(42 * pulse)
    half_h = wing_h // 2

    # Bow-tie wings (colour 13), left then right.
    for sign in (-1, 1):
        tip_x = mid_x + sign * wing_w
        pyxel.tri(tip_x, mid_y - half_h, tip_x, mid_y + half_h, mid_x, mid_y, 13)
    # Upper highlight on each wing (colour 14), left then right.
    for sign in (-1, 1):
        tip_x = mid_x + sign * wing_w
        pyxel.tri(tip_x, mid_y - half_h,
                  tip_x - sign * 18, mid_y - half_h + 6,
                  mid_x + sign * (wing_w // 2), mid_y - wing_h // 4, 14)

    # Central knot with a wobbling outline.
    knot_r = int(10 * pulse)
    pyxel.circ(mid_x, mid_y, knot_r, 13)
    pyxel.circ(mid_x, mid_y, knot_r - 2, 14)
    pyxel.pset(mid_x, mid_y, 15)
    for k in range(4):
        angle = phase + k * math.pi / 2
        pyxel.circb(mid_x + int(math.cos(angle) * 2),
                    mid_y + int(math.sin(angle) * 2),
                    knot_r + 1, 8)

    # Eight sparkles orbiting on an ellipse, twinkling between colours 15 and 13.
    for k in range(8):
        angle = phase * 3 + k * math.tau / 8
        spark_x = mid_x + int(math.cos(angle) * (wing_w * 0.6))
        spark_y = mid_y + int(math.sin(angle) * (wing_h * 0.35))
        if 0 <= spark_x < WIDTH and 0 <= spark_y < HEIGHT:
            pyxel.pset(spark_x, spark_y, 15 if (t + k * 7) % 12 < 6 else 13)

    # Title: letters spaced 14px apart (6px glyph * 2 + 2), spaces skipped.
    title = "KITTY SLAYER"
    step = 6 * 2 + 2
    left = mid_x - (len(title) * step) // 2
    title_y = mid_y - wing_h - 28
    glyphs = [(left + n * step, ch) for n, ch in enumerate(title) if ch != " "]

    # Drop shadow pass.
    for gx, ch in glyphs:
        pyxel.text(gx + 2, title_y + 2, ch, 8)
    # Glow outline pass, then the letter itself.
    glow = 14 if (t // 8) % 2 == 0 else 13
    for gx, ch in glyphs:
        for ox, oy in ((-1, 0), (1, 0), (0, -1), (0, 1)):
            pyxel.text(gx + ox, title_y + oy, ch, glow)
        pyxel.text(gx, title_y, ch, 14)

    subtitle = "THE ULTIMATE CAT ADVENTURE"
    pyxel.text(mid_x - len(subtitle) * 2, mid_y + wing_h + 14, subtitle,
               13 if (t // 10) % 2 == 0 else 15)

    # Blinking start prompt.
    if (t // 20) % 2 == 0:
        prompt = "PRESS SPACE TO START"
        pyxel.text(mid_x - len(prompt) * 2, HEIGHT - 28, prompt, 15)

    pyxel.text(4, HEIGHT - 10, "v3.1", 5)


class FinalBoss(Boss):
    """Final boss (spawned on level 10 by Game.load_level), built on Boss.

    Adds a fixed 500 HP pool, three HP-based phases (see ``_phase``), a
    horizontal charge attack from phase 2 on, a wider bullet spread per
    phase and a screen-wide health bar. Movement/collision helpers
    (``move_and_collide``, ``boss_ground_ahead``) and the timer/flash
    attributes used below come from Boss, which is defined outside this class.
    """
    # Collision box. NOTE(review): presumably Boss.__init__ sizes self.w/self.h
    # from these class attributes — confirm.
    HITBOX_W = 24
    HITBOX_H = 48
    # Sprite size; draw() bottom-aligns the taller sprite on the hitbox.
    SPRITE_W = 24
    SPRITE_H = 64

    def __init__(self, x, y, level_no):
        """Create the boss at (x, y), overriding the base HP with 500."""
        super().__init__(x, y, level_no)
        self.max_hp = 500
        self.hp    = 500
        # Charge-attack state (only used from phase 2 onward).
        self.charge_timer = 0   # frames counted toward the next charge
        self.charging     = False
        self.charge_vx    = 0   # horizontal speed locked in for the current charge

    def _phase(self):
        """Return the current phase: 1 above 60% HP, 2 above 30% HP, else 3."""
        if self.hp > self.max_hp * 0.6:
            return 1
        elif self.hp > self.max_hp * 0.3:
            return 2
        return 3

    def update(self, solids, player, boss_bullets, enemies, spawn_enemy_func):
        """Advance the boss by one frame.

        Args:
            solids: collision rects used for movement, wall and ledge checks.
            player: target; its x/y/w/h steer movement, charges and aim.
            boss_bullets: list that new boss Bullets are appended to.
            enemies: current enemy list; only its length is read (minions
                are only summoned while fewer than 8 exist).
            spawn_enemy_func: zero-argument callback that summons a minion.
        """
        if not self.alive:
            return
        # Gravity, clamped to the terminal fall speed.
        self.vy += GRAVITY
        if self.vy > MAX_FALL:
            self.vy = MAX_FALL

        ph = self._phase()
        # Walk speed grows with phase: 1.7 / 2.4 / 3.1.
        speed = 1.0 + ph * 0.7

        if ph >= 2:
            # Start a dash toward the player every 170 frames (130 in phase 3).
            self.charge_timer += 1
            if self.charge_timer > (130 if ph == 3 else 170):
                self.charge_timer = 0
                self.charging  = True
                self.charge_vx = 7.5 * (1 if player.x > self.x else -1)
                self.dir       = 1 if self.charge_vx > 0 else -1
            if self.charging:
                self.vx = self.charge_vx
                # 4px probe just beyond the leading edge: the charge ends on
                # hitting a solid or on getting within 28px of the player.
                front = (self.x + (self.w if self.vx > 0 else -4), self.y + 4, 4, self.h - 8)
                if any(rects_overlap(front, s) for s in solids) or abs(self.x - player.x) < 28:
                    self.charging = False
                    self.vx = 0
            else:
                # Walk toward the player's horizontal centre.
                self.vx = speed * (1 if player.x + player.w / 2 >= self.x + self.w / 2 else -1)
                self.dir = 1 if self.vx > 0 else -1
        else:
            # Phase 1: just walk toward the player's horizontal centre.
            self.vx = speed * (1 if player.x + player.w / 2 >= self.x + self.w / 2 else -1)
            self.dir = 1 if self.vx > 0 else -1

        if not self.charging:
            # Reverse at walls, and at ledges while standing on the ground.
            front_probe = (self.x + (self.w if self.vx > 0 else -3), self.y + 6, 4, self.h - 10)
            wall_ahead     = any(rects_overlap(front_probe, s) for s in solids)
            no_ground_ahead = not self.boss_ground_ahead(solids, self.vx > 0)
            if wall_ahead or (self.on_ground and no_ground_ahead):
                self.vx *= -1
                self.dir = 1 if self.vx >= 0 else -1

        # Hop when grounded and more than 80 frames (60 in phase 3) have passed.
        self.jump_timer += 1
        if self.on_ground and self.jump_timer > (60 if ph == 3 else 80):
            self.jump_timer = 0
            self.vy = -6.2

        self.move_and_collide(solids)

        # Fire every 38/26/18 frames in phases 1/2/3: one aimed shot, plus two
        # shots with vertical velocity offsets from phase 2 and two with
        # horizontal offsets in phase 3.
        self.shoot_timer += 1
        rate = {1: 38, 2: 26, 3: 18}[ph]
        if self.shoot_timer > rate:
            self.shoot_timer = 0
            cx = self.x + self.w / 2
            cy = self.y + 13
            dx = player.x + player.w / 2 - cx
            dy = player.y + player.h / 2 - cy
            dist = max(1, math.sqrt(dx * dx + dy * dy))  # max(1, ...) avoids division by zero
            spd  = 3.2 + ph * 0.4
            # NOTE(review): the two Bullet args after velocity (8/9, then 16/12/10)
            # look like colour and damage — confirm against Bullet.
            boss_bullets.append(Bullet(cx, cy, dx/dist*spd, dy/dist*spd, 8, 16, max_range=280))
            if ph >= 2:
                for off in (-0.55, 0.55):
                    boss_bullets.append(Bullet(cx, cy, dx/dist*spd, dy/dist*spd + off,
                                               9, 12, max_range=280))
            if ph == 3:
                for off in (-0.3, 0.3):
                    boss_bullets.append(Bullet(cx, cy, dx/dist*spd + off, dy/dist*spd,
                                               8, 10, max_range=280))

        # Summon a minion every 240 frames (160 from phase 2) while fewer than 8 enemies exist.
        self.spawn_timer += 1
        if self.spawn_timer > (160 if ph >= 2 else 240):
            self.spawn_timer = 0
            if len(enemies) < 8:
                spawn_enemy_func()

        if self.hit_flash > 0:
            self.hit_flash -= 1

    def draw(self, cam_x):
        """Draw the boss sprite, its phase label and the top-of-screen HP bar."""
        if not self.alive:
            return
        px   = int(self.x - cam_x)
        py   = int(self.y)
        ph   = self._phase()
        # Blink on alternate frames while hit_flash is counting down.
        flash = self.hit_flash > 0 and self.hit_flash % 2 == 0

        # Phase tint (skipped on flash frames): phase 2 maps colour 14->9,
        # phase 3 maps 14 and 11 -> 8. pyxel.pal() below restores the palette.
        if not flash:
            if ph == 2:
                pyxel.pal(14, 9)
            elif ph == 3:
                pyxel.pal(14, 8)
                pyxel.pal(11, 8)

        # Centre the sprite horizontally and align its bottom with the hitbox.
        soff_x = (self.w - self.SPRITE_W) // 2
        soff_y = self.h - self.SPRITE_H
        blt_sprite("boss", px + soff_x, py + soff_y,
                   flip=(self.dir == -1), flash=flash)
        pyxel.pal()

        # Phase label in a black box above the sprite (index 0 is unused).
        labels = ["", "PHASE I", "PHASE II!", "PHASE III!!!"]
        lcols  = [0, 7, 9, 8]
        lbl = labels[ph]
        pyxel.rect(px - 2, py + soff_y - 12, len(lbl) * 4 + 4, 9, 0)
        pyxel.text(px, py + soff_y - 10, lbl, lcols[ph])

        # Screen-centred HP bar; white ticks mark the 30% / 60% phase thresholds.
        bar_w = 200
        bx    = WIDTH // 2 - bar_w // 2
        pyxel.rect(bx, 8, bar_w, 8, 0)
        fill  = int((self.hp / self.max_hp) * bar_w)
        bcol  = {1: 13, 2: 9, 3: 8}[ph]
        pyxel.rect(bx, 8, fill, 8, bcol)
        pyxel.rectb(bx, 8, bar_w, 8, 7)
        for frac in (0.3, 0.6):
            mx = bx + int(frac * bar_w)
            pyxel.rect(mx - 1, 6, 3, 12, 7)
        pyxel.text(bx + bar_w // 2 - 22, 1, "!! FINAL BOSS !!", lcols[ph])


class Game:
    def __init__(self):
        """Open the window, load assets, set up a fresh run and start the loop.

        The game starts on the title ("loading") screen with level 1 already
        built. Control is then handed to ``pyxel.run`` with ``update`` and
        ``draw`` as the per-frame callbacks.
        """
        pyxel.init(WIDTH, HEIGHT, title="Kitty Slayer")
        setup_sounds()
        load_all_sprites()

        # Menu / flow state.
        self.state = "loading"
        self.loading_timer = 0
        self.selected_skin = 0
        self.god_mode = False

        # Run progress.
        self.level_no = 1
        self.total_score = 0
        self.lives_left = MAX_LIVES
        self.game_over_full = False

        # Loadout carried between levels.
        self.unlocked_guns = {0}
        self.gun_ammo = [gun["ammo"] for gun in GUN_TYPES]
        self.current_gun_idx = 0
        self.has_gravity_gun = False
        self.coins_wallet = 0
        self.upgrades = {"strength": 0, "speed": 0, "jump": 0}

        # Shop UI.
        self.shop_open = False
        self.shop_flash = 0
        self.shop_msg = ""
        self.shop_msg_timer = 0
        self.shop_anim = 0

        # Exit portal (rebuilt by load_level).
        self.level_portal = None
        self.portal_enter_timer = 0

        self.load_level()
        pyxel.playm(1, loop=True)  # looping background music
        pyxel.run(self.update, self.draw)

    def reset_full_game(self):
        """Reset the whole run (after game over or a win) and return to skin select.

        Progress, loadout, upgrades, shop state and the god-mode cheat are
        restored to their initial values, and level 1 is rebuilt.
        """
        self.state = "skin_select"
        self.selected_skin = 0

        self.level_no = 1
        self.total_score = 0
        self.lives_left = MAX_LIVES
        self.game_over_full = False

        self.unlocked_guns = {0}
        self.gun_ammo = [g["ammo"] for g in GUN_TYPES]
        self.current_gun_idx = 0
        self.has_gravity_gun = False
        self.coins_wallet = 0
        self.upgrades = {"strength": 0, "speed": 0, "jump": 0}

        self.shop_open = False
        # Clear shop notification state too: shop_msg_timer only counts down
        # during play, so a message left over from the previous run would
        # otherwise reappear at the start of the new one.
        self.shop_msg = ""
        self.shop_msg_timer = 0
        self.shop_flash = 0
        self.shop_anim = 0

        self.level_portal = None
        self.god_mode = False
        self.load_level()

    def load_level(self):
        """Build level ``self.level_no`` and (re)initialise all per-level state.

        Generates the layout with ``make_procedural_level``, creates the exit
        portal, spawns a fresh Player carrying the saved loadout (guns, ammo,
        upgrades, gravity gun), instantiates enemies, pickups, checkpoints and
        cannons, spawns the boss on boss levels and queues level-start
        notifications.
        """
        level = make_procedural_level(self.level_no)

        # ── Geometry and hazards ──
        self.world_width = level["world_width"]
        self.solids = level["solids"]
        self.spikes = level["spikes"]
        self.top_spikes = level.get("top_spikes", [])
        # Accept either ready-made LavaPool objects or raw constructor tuples.
        self.lava_pools = [
            pool if isinstance(pool, LavaPool) else LavaPool(*pool)
            for pool in level.get("lava_pools", [])
        ]
        self.crushers = level.get("crushers", [])
        self.moving_platforms = level.get("moving_platforms", [])
        self.end_x = level["end_x"]
        self.is_boss_level = level["boss"]
        self.bg_color = level["bg_color"]

        # ── Exit portal ──
        self.level_portal = LevelPortal(
            level.get("portal_x", self.world_width - 100),
            level.get("portal_y", 180),
            self.level_no,
            is_final=level.get("is_final_portal", False),
        )
        self.portal_enter_timer = 0

        # Spawn 2px above the level's real starting ground so the player never
        # starts inside terrain.
        ground_y = level.get("start_ground_y", 240)
        self.player = Player(
            40, ground_y - PLAYER_H - 2,
            self.selected_skin,
            unlocked_guns=set(self.unlocked_guns),
            gun_ammo=list(self.gun_ammo),
            gun_idx=self.current_gun_idx,
            upgrades=dict(self.upgrades),
            has_gravity_gun=self.has_gravity_gun,
        )

        # Projectiles and effects never carry over between levels.
        self.bullets = []
        self.rockets = []
        self.enemy_bullets = []
        self.boss_bullets = []
        self.black_holes = []
        self.particles = []

        # ── Entities from the level specs ──
        self.enemies = [Enemy(*spec) for spec in level["enemy_specs"]]
        self.coins = [Coin(*spec) for spec in level["coins"]]
        self.gun_pickups = [GunPickup(*spec) for spec in level["gun_pickups"]]
        gravity_spec = level["gravity_pickup"]
        self.gravity_pickup = GravityPickup(*gravity_spec) if gravity_spec else None
        self.checkpoints = [Checkpoint(*spec) for spec in level["checkpoints"]]
        self.cannons = [Cannon(*spec) for spec in level.get("cannons", [])]
        self.pipe_rect_set = frozenset(tuple(rect) for rect in level.get("pipe_solids", []))

        # Until a checkpoint is reached, respawn at the level start.
        self.last_checkpoint_x = 40
        self.last_checkpoint_y = ground_y
        self.score = 0
        self.cam_x = 0.0
        self.boss_intro_timer = 70 if self.is_boss_level else 0

        self.boss = None
        if self.is_boss_level and level["boss_spawn"]:
            spawn = level["boss_spawn"]
            boss_cls = FinalBoss if self.level_no == 10 else Boss
            self.boss = boss_cls(spawn[0], spawn[1], self.level_no)

        # ── Notifications (the last gun pickup listed wins) ──
        self.gun_notify_timer = 0
        self.gun_notify_text = ""
        for pickup in level["gun_pickups"]:
            self.gun_notify_text = f"NEW GUN: {GUN_TYPES[pickup[2]]['full_name']}!"
            self.gun_notify_timer = 120
        if self.gravity_pickup and not self.has_gravity_gun:
            self.gun_notify_text = "GRAVITY GUN AVAILABLE LATE GAME!"
            self.gun_notify_timer = 140

        if self.level_no >= 4:
            self.notify(f"LVL {self.level_no}: BUY POTIONS! [TAB]", 200)

    def spawn_particles(self, x, y, col, count=8):
        """Emit *count* Particle objects of colour *col* at (x, y)."""
        self.particles.extend(Particle(x, y, col) for _ in range(count))

    def save_player_state(self):
        self.unlocked_guns = set(self.player.unlocked_guns)
        self.gun_ammo = list(self.player.gun_ammo)
        self.current_gun_idx = self.player.gun_idx
        self.has_gravity_gun = self.player.has_gravity_gun

    def respawn_player(self):
        """Replace the player with a fresh one at the last checkpoint.

        The new player carries the saved loadout and spawns 2px above the
        checkpoint ground so its feet do not clip into terrain. All
        in-flight projectiles and black holes are removed.
        """
        self.player = Player(
            self.last_checkpoint_x,
            self.last_checkpoint_y - PLAYER_H - 2,
            self.selected_skin,
            unlocked_guns=set(self.unlocked_guns),
            gun_ammo=list(self.gun_ammo),
            gun_idx=self.current_gun_idx,
            upgrades=dict(self.upgrades),
            has_gravity_gun=self.has_gravity_gun,
        )
        # Respawn HP: 60 + 5 per "jump" upgrade level, capped at max_hp.
        # NOTE(review): scaling with the jump upgrade looks deliberate but odd — confirm.
        self.player.hp = min(self.player.max_hp, 60 + self.upgrades["jump"] * 5)
        for projectiles in (self.bullets, self.rockets, self.enemy_bullets,
                            self.boss_bullets, self.black_holes):
            projectiles.clear()

    def spawn_extra_enemy_near_boss(self):
        if not self.boss or not self.boss.alive:
            return
        tries = 0
        while tries < 20:
            tries += 1
            px, py = spawn_on_ground(self.solids, 22, 8)
            if abs(px - self.boss.x) < 140 and px > 100:
                hp = 2 + self.level_no // 2
                speed = 0.95 + self.level_no * 0.06
                kind = random.choice([0, 1, 2, 3])
                self.enemies.append(Enemy(px, py - 30, hp, speed, kind))
                return

    def spike_hit(self, rect):
        for s in self.spikes:
            if rects_overlap(rect, s):
                return True
        return False

    def top_spike_hit(self, rect):
        for ts in self.top_spikes:
            if rects_overlap(rect, ts.rect()):
                return True
        return False

    def lava_hit(self, rect):
        for lp in self.lava_pools:
            if rects_overlap(rect, lp.rect()):
                return True
        return False

    def crusher_hit(self, rect):
        for cr in self.crushers:
            if cr.state in ("drop", "hold") and rects_overlap(rect, cr.rect()):
                return True
        return False

    def solid_bullet_list(self):
        sl = list(self.solids)
        for mp in self.moving_platforms:
            sl.append(mp.rect())
        return sl + self.spikes

    def lose_life(self):
        self.lives_left -= 1
        self.state = "dead"
        if self.lives_left <= 0:
            self.game_over_full = True

    def notify(self, text, timer=120):
        self.shop_msg = text
        self.shop_msg_timer = timer

    def _shop_buy_sound(self, success):
        """Play the purchase jingle on success, otherwise a UI click (channel 0)."""
        if success:
            pyxel.play(0, SND_LEVEL_COMPLETE)
        else:
            pyxel.play(0, SND_UI_CLICK)

    def try_buy(self, what):
        """Attempt to buy shop item *what*, paying from ``self.coins_wallet``.

        Items:
          * ``"strength"`` / ``"speed"`` / ``"jump"``: stat potions, max level 5.
            The price rises with each level already bought.
          * ``"heart"``: one extra life, up to ``MAX_LIVES + 3`` lives.
          * ``"ammo"``: adds half the capacity (at least 1) to every gun with a
            positive ammo capacity, capped at that capacity.

        Every attempt shows a notification and plays a success or failure
        sound. Unknown item names are ignored.
        """
        def reject(msg):
            self.notify(msg, 90)
            self._shop_buy_sound(False)

        def accept(msg):
            self.notify(msg, 120)
            self._shop_buy_sound(True)

        # item -> (base cost, extra cost per owned level, success msg, maxed msg)
        stat_potions = {
            "strength": (SHOP_STRENGTH_COST, 8, "POWER POTION DRUNK!", "STRENGTH MAXED"),
            "speed": (SHOP_SPEED_COST, 7, "SWIFTNESS POTION DRUNK!", "SPEED MAXED"),
            "jump": (SHOP_JUMP_COST, 7, "LEAP POTION DRUNK!", "JUMP MAXED"),
        }

        if what in stat_potions:
            base_cost, step, ok_msg, maxed_msg = stat_potions[what]
            level = self.upgrades[what]
            cost = base_cost + level * step
            if level >= 5:
                reject(maxed_msg)
            elif self.coins_wallet < cost:
                reject("NOT ENOUGH COINS")
            else:
                self.coins_wallet -= cost
                self.upgrades[what] += 1
                accept(ok_msg)

        elif what == "heart":
            cost = SHOP_HEART_COST + max(0, self.lives_left - 1) * 2
            if self.lives_left >= MAX_LIVES + 3:
                # Report the cap explicitly; "NOT ENOUGH COINS" was misleading here.
                reject("LIVES MAXED")
            elif self.coins_wallet < cost:
                reject("NOT ENOUGH COINS")
            else:
                self.coins_wallet -= cost
                self.lives_left += 1
                accept("LIFE POTION DRUNK!")

        elif what == "ammo":
            cost = SHOP_AMMO_COST
            # (gun index, capacity) for every gun with a positive ammo capacity.
            refillable = [(gi, gun["ammo"]) for gi, gun in enumerate(GUN_TYPES)
                          if gun["ammo"] > 0]
            ammo = self.player.gun_ammo
            if all(ammo[gi] >= cap for gi, cap in refillable):
                # Nothing to refill: don't take the player's coins for a no-op.
                reject("AMMO FULL")
            elif self.coins_wallet < cost:
                reject("NOT ENOUGH COINS")
            else:
                self.coins_wallet -= cost
                for gi, cap in refillable:
                    ammo[gi] = min(cap, ammo[gi] + max(1, cap // 2))
                accept("AMMO REFILLED!")

    def update_shop(self):
        """Per-frame shop logic: advance the animation and handle number-key purchases."""
        self.shop_anim += 1
        # Checked in order 1..5; several keys pressed on the same frame all buy.
        hotkeys = (
            (pyxel.KEY_1, "strength"),
            (pyxel.KEY_2, "speed"),
            (pyxel.KEY_3, "jump"),
            (pyxel.KEY_4, "heart"),
            (pyxel.KEY_5, "ammo"),
        )
        for key, item in hotkeys:
            if pyxel.btnp(key):
                self.try_buy(item)

    def check_portal_enter(self):
        if not self.level_portal or not self.level_portal.active:
            return False
        if self.portal_enter_timer < 60:
            return False
        pr = self.level_portal.rect()
        player_rect = (self.player.x, self.player.y, self.player.w, self.player.h)
        return rects_overlap(player_rect, pr)

    def update(self):
        if self.state == "loading":
            self.loading_timer += 1
            if pyxel.btnp(pyxel.KEY_SPACE) or pyxel.btnp(pyxel.KEY_RETURN):
                if self.loading_timer > 30:
                    self.state = "skin_select"
            return

        if self.state == "skin_select":
            if pyxel.btnp(pyxel.KEY_LEFT) or pyxel.btnp(pyxel.KEY_A):
                self.selected_skin = (self.selected_skin - 1) % len(KITTY_NAMES)
                pyxel.play(0, SND_UI_CLICK)
            if pyxel.btnp(pyxel.KEY_RIGHT) or pyxel.btnp(pyxel.KEY_D):
                self.selected_skin = (self.selected_skin + 1) % len(KITTY_NAMES)
                pyxel.play(0, SND_UI_CLICK)
            if pyxel.btnp(pyxel.KEY_RETURN) or pyxel.btnp(pyxel.KEY_SPACE):
                pyxel.play(0, SND_LEVEL_COMPLETE)
                self.level_no = 1
                self.total_score = 0
                self.lives_left = MAX_LIVES
                self.game_over_full = False
                self.unlocked_guns = {0}
                self.gun_ammo = [g["ammo"] for g in GUN_TYPES]
                self.current_gun_idx = 0
                self.has_gravity_gun = False
                self.coins_wallet = 0
                self.upgrades = {"strength": 0, "speed": 0, "jump": 0}
                self.load_level()
                self.state = "controls"
            return

        if self.state == "controls":
            if pyxel.btnp(pyxel.KEY_SPACE) or pyxel.btnp(pyxel.KEY_RETURN):
                pyxel.play(0, SND_UI_CLICK)
                self.state = "play"
            return

        if self.state == "dead":
            if pyxel.btnp(pyxel.KEY_RETURN) or pyxel.btnp(pyxel.KEY_SPACE):
                if self.game_over_full:
                    self.reset_full_game()
                else:
                    self.respawn_player()
                    self.state = "play"
            return

        if self.state == "win":
            if pyxel.btnp(pyxel.KEY_RETURN) or pyxel.btnp(pyxel.KEY_SPACE):
                self.reset_full_game()
            return

        if pyxel.btnp(pyxel.KEY_TAB) or pyxel.btnp(pyxel.KEY_B):
            self.shop_open = not self.shop_open
            if not self.shop_open:
                self.gun_ammo = list(self.player.gun_ammo)

        # --- GOD MODE toggle (M) ---
        if pyxel.btnp(pyxel.KEY_M):
            self.god_mode = not self.god_mode
            self.player.fly_mode = self.god_mode

        # Sync fly_mode and handle god-mode-only cheats
        if self.god_mode:
            self.player.fly_mode = True
            if pyxel.btnp(pyxel.KEY_K) and self.boss and self.boss.alive:
                self.boss.hp = 0
                self.boss.alive = False
                self.spawn_particles(self.boss.x + self.boss.w / 2,
                                     self.boss.y + self.boss.h / 2, 10, 40)
        else:
            self.player.fly_mode = False

        if self.shop_msg_timer > 0:
            self.shop_msg_timer -= 1

        if self.shop_open:
            self.update_shop()
            return

        if self.boss_intro_timer > 0:
            self.boss_intro_timer -= 1
            return

        if self.gun_notify_timer > 0:
            self.gun_notify_timer -= 1

        self.portal_enter_timer += 1

        for mp in self.moving_platforms:
            mp.update()
        for lp in self.lava_pools:
            lp.update()
        for cr in self.crushers:
            cr.update(self.solids)
        for cn in self.cannons:
            cn.update(self.enemy_bullets)

        for puff in self.player.jump_puff:
            self.particles.append(puff)
        self.player.jump_puff.clear()

        self.player.update(self.solids, self.world_width, self.moving_platforms)

        if pyxel.btn(pyxel.KEY_Z):
            made = self.player.shoot()
            for b in made:
                if isinstance(b, Rocket):
                    self.rockets.append(b)
                else:
                    self.bullets.append(b)

        if pyxel.btnp(pyxel.KEY_X):
            bh = self.player.fire_black_hole()
            if bh:
                self.black_holes.append(bh)
                self.spawn_particles(self.player.x + self.player.w / 2,
                                     self.player.y + self.player.h / 2, 13, 10)

        solid_blocks = self.solid_bullet_list()

        for bh in self.black_holes:
            bh.update(self.enemies, self.boss, solid_blocks)
        self.black_holes = [b for b in self.black_holes if b.alive]

        for b in self.bullets:
            b.update(solid_blocks)
        for r in self.rockets:
            r.update(solid_blocks)
        for b in self.enemy_bullets:
            b.update(solid_blocks)
        for b in self.boss_bullets:
            b.update(solid_blocks)
        for c in self.coins:
            c.update()
        for g in self.gun_pickups:
            g.update()
        if self.gravity_pickup:
            self.gravity_pickup.update()
        for p in self.particles:
            p.update()
        self.particles = [p for p in self.particles if p.life > 0]

        for e in self.enemies:
            e.update(self.solids, self.spikes, self.player, self.enemy_bullets)
        self.enemies = [e for e in self.enemies if e.alive]

        if self.boss and self.boss.alive:
            self.boss.update(self.solids, self.player, self.boss_bullets,
                             self.enemies, self.spawn_extra_enemy_near_boss)

        alive_bullets = []
        for b in self.bullets:
            if not b.alive:
                continue
            hit = False
            for e in self.enemies:
                if rects_overlap((b.x - 2, b.y - 2, 4, 4), (e.x, e.y, e.w, e.h)):
                    e.hp -= b.dmg
                    e.hit_flash = 6
                    self.spawn_particles(e.x + e.w / 2, e.y + e.h / 2, 8, 5)
                    if e.hp <= 0:
                        e.alive = False
                        self.score += 100
                        self.coins_wallet += 2
                        self.spawn_particles(e.x + e.w / 2, e.y + e.h / 2, 11, 16)
                        pyxel.play(1, SND_ENEMY_DEATH)
                    hit = True
                    break
            if not hit and self.boss and self.boss.alive:
                if rects_overlap((b.x - 2, b.y - 2, 4, 4),
                                 (self.boss.x, self.boss.y, self.boss.w, self.boss.h)):
                    self.boss.take_hit(b.dmg)
                    self.spawn_particles(self.boss.x + self.boss.w / 2,
                                        self.boss.y + self.boss.h / 2, 8, 5)
                    hit = True
                    if not self.boss.alive:
                        self.score += 600
                        self.coins_wallet += 10
                        self.spawn_particles(self.boss.x + self.boss.w / 2,
                                             self.boss.y + self.boss.h / 2, 10, 40)
                        pyxel.play(1, SND_ENEMY_DEATH)
            if not hit:
                alive_bullets.append(b)
        self.bullets = alive_bullets

        alive_rockets = []
        for r in self.rockets:
            if not r.alive:
                continue
            if not r.exploded:
                for e in self.enemies:
                    if rects_overlap((r.x - 4, r.y - 2, 8, 4), (e.x, e.y, e.w, e.h)):
                        r.explode()
                        break
                if self.boss and self.boss.alive and not r.exploded:
                    if rects_overlap((r.x - 4, r.y - 2, 8, 4),
                                     (self.boss.x, self.boss.y, self.boss.w, self.boss.h)):
                        r.explode()
            if r.exploded:
                ex = r.explosion_rect()
                if ex and r.timer == 15:
                    for e in self.enemies:
                        if rects_overlap(ex, (e.x, e.y, e.w, e.h)):
                            e.hp -= r.dmg
                            e.hit_flash = 10
                            if e.hp <= 0:
                                e.alive = False
                                self.score += 100
                                self.coins_wallet += 2
                                self.spawn_particles(e.x + e.w / 2, e.y + e.h / 2, 9, 18)
                                pyxel.play(1, SND_ENEMY_DEATH)
                    if self.boss and self.boss.alive and rects_overlap(
                            ex, (self.boss.x, self.boss.y, self.boss.w, self.boss.h)):
                        self.boss.take_hit(r.dmg)
                        if not self.boss.alive:
                            self.score += 600
                            self.coins_wallet += 10
                            self.spawn_particles(self.boss.x + self.boss.w / 2,
                                                 self.boss.y + self.boss.h / 2, 10, 40)
                            pyxel.play(1, SND_ENEMY_DEATH)
            if r.alive:
                alive_rockets.append(r)
        self.rockets = alive_rockets
        self.enemies = [e for e in self.enemies if e.alive]

        new_enemy_bullets = []
        for b in self.enemy_bullets:
            if not b.alive:
                continue
            if rects_overlap((b.x - 2, b.y - 2, 4, 4),
                             (self.player.x, self.player.y, self.player.w, self.player.h)):
                self.player.take_damage(14)
                self.spawn_particles(self.player.x + self.player.w / 2,
                                     self.player.y + self.player.h / 2, 8, 8)
            else:
                new_enemy_bullets.append(b)
        self.enemy_bullets = new_enemy_bullets

        new_boss_bullets = []
        for b in self.boss_bullets:
            if not b.alive:
                continue
            if rects_overlap((b.x - 2, b.y - 2, 4, 4),
                             (self.player.x, self.player.y, self.player.w, self.player.h)):
                self.player.take_damage(18)
                self.spawn_particles(self.player.x + self.player.w / 2,
                                     self.player.y + self.player.h / 2, 8, 10)
            else:
                new_boss_bullets.append(b)
        self.boss_bullets = new_boss_bullets

        for e in self.enemies:
            if rects_overlap((self.player.x, self.player.y, self.player.w, self.player.h),
                             (e.x, e.y, e.w, e.h)):
                self.player.take_damage(e.contact_damage())
                self.spawn_particles(self.player.x + 8, self.player.y + 8, 14, 8)
                pyxel.play(1, SND_ENEMY_ATTACK)
            atk = e.sword_attack_rect()
            if atk and rects_overlap(atk, (self.player.x, self.player.y,
                                           self.player.w, self.player.h)):
                self.player.take_damage(28)
                self.spawn_particles(self.player.x + 8, self.player.y + 8, 10, 8)
                pyxel.play(1, SND_ENEMY_ATTACK)

        if self.boss and self.boss.alive:
            if rects_overlap((self.player.x, self.player.y, self.player.w, self.player.h),
                             (self.boss.x, self.boss.y, self.boss.w, self.boss.h)):
                self.player.take_damage(22)
                self.spawn_particles(self.player.x + 8, self.player.y + 8, 14, 10)

        player_rect = (self.player.x, self.player.y, self.player.w, self.player.h)
        if self.spike_hit(player_rect):
            self.player.take_damage(30)
            self.spawn_particles(self.player.x + 8, self.player.y + 10, 8, 10)
        if self.top_spike_hit(player_rect):
            self.player.take_damage(25)
            self.spawn_particles(self.player.x + 8, self.player.y, 9, 10)
        if self.lava_hit(player_rect):
            self.player.take_damage(40)
            self.spawn_particles(self.player.x + 8, self.player.y + 14, 8, 12)
        if self.crusher_hit(player_rect):
            self.player.take_damage(50)
            self.spawn_particles(self.player.x + 8, self.player.y + 5, 9, 14)

        for c in self.coins:
            if (not c.collected
                    and abs((self.player.x + self.player.w / 2) - c.x) < 12
                    and abs((self.player.y + self.player.h / 2) - c.y) < 12):
                c.collected = True
                self.score += 10
                self.coins_wallet += 1
                self.spawn_particles(c.x, c.y, 10, 6)

        for g in self.gun_pickups:
            if (not g.collected
                    and abs((self.player.x + self.player.w / 2) - g.x) < 18
                    and abs((self.player.y + self.player.h / 2) - g.y) < 18):
                g.collected = True
                gun_idx = g.gun_idx
                self.player.unlocked_guns.add(gun_idx)
                self.player.gun_ammo[gun_idx] = GUN_TYPES[gun_idx]["ammo"]
                self.score += 50
                self.coins_wallet += 4
                self.spawn_particles(g.x, g.y, GUN_TYPES[gun_idx]["col"], 16)
                self.gun_notify_text = f"PICKED UP: {GUN_TYPES[gun_idx]['full_name']}!"
                self.gun_notify_timer = 180

        if self.gravity_pickup and (not self.gravity_pickup.collected):
            if (abs((self.player.x + self.player.w / 2) - self.gravity_pickup.x) < 18
                    and abs((self.player.y + self.player.h / 2) - self.gravity_pickup.y) < 18):
                self.gravity_pickup.collected = True
                self.player.has_gravity_gun = True
                self.has_gravity_gun = True
                self.score += 80
                self.coins_wallet += 5
                self.spawn_particles(self.gravity_pickup.x, self.gravity_pickup.y, 13, 22)
                self.notify("GRAVITY GUN PICKED UP", 160)

        for cp in self.checkpoints:
            if not cp.active and abs((self.player.x + self.player.w / 2) - cp.x) < 14:
                cp.active = True
                self.last_checkpoint_x = cp.x
                self.last_checkpoint_y = cp.y
                self.spawn_particles(cp.x, cp.y, 11, 10)

        if self.player.dead and self.player.death_timer <= 0:
            self.save_player_state()
            self.lose_life()
            return

        can_use_portal = True
        if self.is_boss_level and self.boss and self.boss.alive:
            can_use_portal = False

        if not self.player.dead and can_use_portal and self.check_portal_enter():
            self.save_player_state()
            self.total_score += self.score
            self.level_no += 1
            if self.level_no > LEVEL_COUNT_FOR_WIN:
                self.state = "win"
                pyxel.play(0, SND_WIN_MELODY)
            else:
                pyxel.play(0, SND_LEVEL_COMPLETE)
                self.load_level()
                return

        target_cam = self.player.x - WIDTH // 3
        self.cam_x += (target_cam - self.cam_x) * 0.12
        self.cam_x = clamp(self.cam_x, 0, self.world_width - WIDTH)

    def draw_background(self):
        """Paint the scrolling backdrop behind the level.

        ``self.bg_color`` selects the sky treatment:
          * 0  -> colour-0 sky with a sparse, very slowly drifting dot field;
          * 1  -> colour-1 sky with a row of downward-pointing triangles;
          * anything else -> colour-5 sky with triangles of varying depth.
        Two rows of hill triangles are then drawn along the bottom edge,
        scrolling at 1/6 and 1/3 of the camera speed for parallax.
        """
        theme = self.bg_color

        if theme == 0:
            pyxel.cls(0)
            drift = int(self.cam_x * 0.03)
            for n in range(24):
                dot_x = (n * 173 + drift) % WIDTH
                dot_y = (n * 89 + 10) % (HEIGHT - 20)
                pyxel.pset(dot_x, dot_y, 2 if n % 3 == 0 else 1)
        elif theme == 1:
            pyxel.cls(1)
            shift = int(self.cam_x * 0.05)
            for n in range(12):
                left = (n * 150 - shift) % (WIDTH + 60) - 20
                pyxel.tri(left, 0, left + 40, 40, left + 80, 0, 2)
        else:
            pyxel.cls(5)
            shift = int(self.cam_x * 0.04)
            for n in range(10):
                left = (n * 130 - shift) % (WIDTH + 100) - 20
                depth = 26 + (n * 19) % 28
                pyxel.tri(left, 0, left + 28, depth, left + 56, 0, 2)

        # Parallax hills. NOTE(review): the far-layer colour switches on
        # theme == 5, a value none of the sky branches above tests for;
        # confirm which bg_color values levels actually use.
        cam = int(self.cam_x)
        far_col = 1 if theme == 5 else 2
        near_col = 1 if theme == 0 else 2
        horizon = HEIGHT - 40

        for n in range(10):
            left = (n * 130 - cam // 6) % (WIDTH + 130) - 20
            peak = 28 + (n * 17) % 30
            pyxel.tri(left, horizon, left + 30, horizon - peak, left + 60, horizon,
                      far_col)

        for n in range(8):
            left = (n * 110 - cam // 3) % (WIDTH + 110) - 20
            peak = 30 + (n * 23) % 28
            pyxel.tri(left, HEIGHT, left + 22, HEIGHT - peak, left + 44, HEIGHT,
                      near_col)

    def draw_solids(self):
        """Render level terrain and hazards relative to the camera.

        Terrain colours come from a ten-entry per-level palette (levels past
        the tenth reuse the last entry). Rectangles that appear in
        ``self.pipe_rect_set`` get the pipe treatment (20-px-wide pieces and
        other widths are coloured differently); every other rectangle in
        ``self.solids`` is drawn as ground with a top strip, shaded side
        walls, an outline and, when taller than 24 px, horizontal crack
        lines every 14 px. Floor spikes are drawn here in 8-px teeth; top
        spikes, moving platforms, lava pools and crushers draw themselves.
        Anything entirely outside [0, WIDTH] is skipped.
        """
        palette_slot = min(self.level_no - 1, 9)
        ground_col = (4, 3, 4, 3, 2, 4, 3, 2, 4, 3)[palette_slot]
        top_col = (9, 11, 9, 11, 8, 9, 11, 8, 9, 11)[palette_slot]

        for rect in self.solids:
            sx = int(rect[0] - self.cam_x)
            sy = int(rect[1])
            w = rect[2]
            h = rect[3]
            if sx + w < 0 or sx > WIDTH:
                continue

            if tuple(rect) not in self.pipe_rect_set:
                pyxel.rect(sx, sy, w, h, ground_col)
                pyxel.rect(sx, sy, w, 4, top_col)
                pyxel.rect(sx, sy, 2, h, 1)
                pyxel.rect(sx + w - 2, sy, 2, h, 1)
                pyxel.rectb(sx, sy, w, h, 2)
                if h > 24:
                    crack_w = max(1, w - 4)
                    for crack_y in range(sy + 14, sy + h - 4, 14):
                        pyxel.rect(sx + 2, crack_y, crack_w, 1, 1)
            elif w == 20:
                pyxel.rect(sx, sy, w, h, 3)
                pyxel.rectb(sx, sy, w, h, 11)
            else:
                pyxel.rect(sx, sy, w, h, 11)
                pyxel.rectb(sx, sy, w, h, 3)
                pyxel.rect(sx + 1, sy, w - 2, 3, 3)

        for spike in self.spikes:
            sx = int(spike[0] - self.cam_x)
            sy = int(spike[1])
            span = spike[2]
            if sx + span < 0 or sx > WIDTH:
                continue
            for offset in range(0, span, 8):
                left = sx + offset
                pyxel.rect(left, sy + 4, 8, 4, 4)
                pyxel.tri(left, sy + 8, left + 4, sy, left + 8, sy + 8, 9)
                pyxel.pset(left + 4, sy + 2, 10)

        for hazards in (self.top_spikes, self.moving_platforms,
                        self.lava_pools, self.crushers):
            for hazard in hazards:
                hazard.draw(self.cam_x)

    def draw_hud(self):
        """Draw the in-game heads-up display over the world view.

        Top-left: HP bar, level counter, lives, coin wallet, a blinking
        GOD MODE badge when enabled, and the gravity-gun cooldown bar.
        Centre: a boss hint while a boss level's boss is alive, a portal
        prompt when the player is near an active portal, and the transient
        pickup / shop messages. Top-right: the selected kitty's name and
        trait, the weapon list with ammo counts (negative ammo = infinite),
        and the purchased upgrade levels.

        Both bars are clamped to their frames, so HP above ``max_hp`` or a
        cooldown timer longer than the current cooldown cannot draw past
        the border or with a negative width.
        """
        player = self.player

        # HP bar: up to 86 px of fill inside a 90 px frame.
        pyxel.rect(8, 8, 90, 10, 0)
        hp_ratio = max(0, player.hp) / max(1, player.max_hp)
        hp_w = min(86, int(hp_ratio * 86))
        hp_col = 11 if player.hp > 50 else (10 if player.hp > 25 else 8)
        pyxel.rect(10, 9, hp_w, 8, hp_col)
        pyxel.rectb(8, 8, 90, 10, 7)
        pyxel.text(11, 10, "HP", 7)

        pyxel.text(8, 22, f"LVL {self.level_no}/{LEVEL_COUNT_FOR_WIN}", 7)
        pyxel.text(8, 30, f"LIVES {self.lives_left}", 8 if self.lives_left <= 2 else 11)
        pyxel.text(8, 38, f"COINS {self.coins_wallet}", 10)
        if self.god_mode:
            gm_col = 10 if (pyxel.frame_count // 15) % 2 == 0 else 9
            pyxel.rect(8, 46, 38, 8, 0)
            pyxel.rectb(8, 46, 38, 8, gm_col)
            pyxel.text(10, 48, "GOD MODE", gm_col)

        # Gravity-gun cooldown: full (36 px) when ready, otherwise the
        # fraction of the cooldown already elapsed, clamped to [0, 36].
        gcd = player.grav_timer
        if gcd > 0:
            ready_frac = 1 - gcd / max(1, player.current_grav_cooldown())
            gbar = max(0, min(36, int(ready_frac * 36)))
        else:
            gbar = 36
        pyxel.rect(8, 56, 36, 4, 0)
        pyxel.rect(8, 56, gbar, 4, 13 if gcd == 0 else 5)
        pyxel.rectb(8, 56, 36, 4, 7)
        pyxel.text(48, 54, "BH" if player.has_gravity_gun else "--",
                   13 if gcd == 0 else 5)

        if self.is_boss_level and self.boss and self.boss.alive:
            hint = "DEFEAT BOSS TO OPEN PORTAL!"
            pyxel.rect(WIDTH // 2 - len(hint) * 2 - 2, 68, len(hint) * 4 + 4, 9, 0)
            pyxel.rectb(WIDTH // 2 - len(hint) * 2 - 2, 68, len(hint) * 4 + 4, 9, 8)
            pyxel.text(WIDTH // 2 - len(hint) * 2, 70, hint, 8)

        # Portal prompt: only once the portal has been around for 60+ frames,
        # it is on screen, and the player is within 80 px of it horizontally.
        if self.level_portal and self.portal_enter_timer >= 60:
            portal_screen_x = self.level_portal.x - self.cam_x
            if 0 < portal_screen_x < WIDTH:
                dist = abs((player.x + player.w / 2) - self.level_portal.x)
                if dist < 80:
                    col = 10 if self.level_portal.is_final else 13
                    lbl = ("ENTER FINISH PORTAL!" if self.level_portal.is_final
                           else f"ENTER LEVEL {self.level_no + 1} PORTAL!")
                    pyxel.text(WIDTH // 2 - len(lbl) * 2, HEIGHT - 24, lbl, col)

        # Right-hand panel: kitty info, one 13-px row per gun, upgrades.
        panel_x = WIDTH - 118
        panel_y = 8
        panel_h = 14 + len(GUN_TYPES) * 13 + 28
        pyxel.rect(panel_x, panel_y, 110, panel_h, 0)
        pyxel.rectb(panel_x, panel_y, 110, panel_h, 5)
        pyxel.text(panel_x + 6, panel_y + 4,
                   KITTY_NAMES[self.selected_skin], KITTY_ACCENT[self.selected_skin])
        pyxel.text(panel_x + 52, panel_y + 4,
                   KITTY_TRAITS[self.selected_skin]["name"], 7)

        for i, gun in enumerate(GUN_TYPES):
            gy = panel_y + 14 + i * 13
            selected = i == player.gun_idx
            unlocked = i in player.unlocked_guns
            if not unlocked:
                pyxel.rect(panel_x + 4, gy, 102, 11, 0)
                pyxel.rectb(panel_x + 4, gy, 102, 11, 1)
                pyxel.text(panel_x + 6, gy + 2, f"[{gun['full_name']}]", 5)
                continue
            bg_col = 1 if selected else 0
            border_col = gun["col"] if selected else 5
            pyxel.rect(panel_x + 4, gy, 102, 11, bg_col)
            pyxel.rectb(panel_x + 4, gy, 102, 11, border_col)
            name_col = 15 if selected else gun["col"]
            pyxel.text(panel_x + 6, gy + 2, gun["full_name"], name_col)
            ammo = player.gun_ammo[i]
            if ammo < 0:
                txt, acol = "INF", 11
            elif ammo == 0:
                txt, acol = "0", 8
            else:
                txt = str(ammo)
                acol = 10 if ammo > 5 else 9
            # Right-align the ammo count against panel_x + 96.
            pyxel.text(panel_x + 96 - len(txt) * 4, gy + 2, txt, acol)

        uy = panel_y + 14 + len(GUN_TYPES) * 13 + 2
        pyxel.text(panel_x + 6,  uy, f"STR {self.upgrades['strength']}", 8)
        pyxel.text(panel_x + 44, uy, f"SPD {self.upgrades['speed']}", 12)
        pyxel.text(panel_x + 82, uy, f"JMP {self.upgrades['jump']}", 11)
        pyxel.text(panel_x + 6, uy + 10, "TAB SHOP  Q/E GUN", 7)

        # Pickup notification blinks (visible 5 of every 8 frames).
        if self.gun_notify_timer > 0 and self.gun_notify_timer % 8 < 5:
            msg = self.gun_notify_text
            mw = len(msg) * 4 + 6
            mx = WIDTH // 2 - mw // 2
            pyxel.rect(mx, 70, mw, 11, 0)
            pyxel.rectb(mx, 70, mw, 11, 11)
            pyxel.text(mx + 3, 73, msg, 11)

        if self.shop_msg_timer > 0:
            msg = self.shop_msg
            mw = len(msg) * 4 + 6
            mx = WIDTH // 2 - mw // 2
            pyxel.rect(mx, 84, mw, 11, 0)
            pyxel.rectb(mx, 84, mw, 11, 14)
            pyxel.text(mx + 3, 87, msg, 14)

    def draw_shop(self):
        """Draw the potion-shop overlay on top of the game view.

        Blanks every other screen row with colour 0, then draws a 380x210
        panel with an animated glow border, the coin balance, and five
        purchase slots bound to keys 1-5: POWER, SWIFTNESS and LEAP upgrades
        (price = base cost + current level * 8 / 7 / 7), LIFE (price grows
        with lives already held) and an AMMO refill. ``self.shop_anim``
        drives the blinking highlights.
        """
        for yy in range(0, HEIGHT, 2):
            pyxel.rect(0, yy, WIDTH, 1, 0)

        pw = 380
        ph = 210
        px = WIDTH // 2 - pw // 2
        py = HEIGHT // 2 - ph // 2

        glow_col = 14 if (self.shop_anim // 10) % 2 == 0 else 13
        pyxel.rect(px - 3, py - 3, pw + 6, ph + 6, glow_col)
        pyxel.rect(px - 2, py - 2, pw + 4, ph + 4, 8)
        pyxel.rect(px, py, pw, ph, 0)
        pyxel.rectb(px, py, pw, ph, 13)

        title = "~ POTION SHOP ~"
        title_x = WIDTH // 2 - len(title) * 2
        # Shadow first, main colour last: drawing the offset copy after the
        # main one would overwrite most of the title's pixels. This matches
        # the back-to-front layering used for the win banner.
        pyxel.text(title_x - 1, py + 8, title, 13)
        pyxel.text(title_x, py + 8, title, 14)

        coins_txt = f"COINS: {self.coins_wallet}"
        pyxel.text(WIDTH // 2 - len(coins_txt) * 2, py + 18, coins_txt, 10)

        pyxel.rect(px + 10, py + 28, pw - 20, 1, 5)

        # (name, colour, level shown as dots (-1 = none), price, hotkey)
        potion_data = [
            ("POWER",     8,  self.upgrades["strength"], SHOP_STRENGTH_COST + self.upgrades["strength"] * 8, "1"),
            ("SWIFTNESS", 12, self.upgrades["speed"],    SHOP_SPEED_COST    + self.upgrades["speed"]    * 7, "2"),
            ("LEAP",      11, self.upgrades["jump"],     SHOP_JUMP_COST     + self.upgrades["jump"]     * 7, "3"),
            ("LIFE",      8,  min(self.lives_left, 5),   SHOP_HEART_COST    + max(0, self.lives_left - 1) * 2, "4"),
            ("AMMO",      10, -1,                        SHOP_AMMO_COST,                                      "5"),
        ]
        # Descriptions for the four non-AMMO slots, indexed by slot position.
        descs = ["Damage+", "Speed+", "Jump+", "+1 Life"]

        slot_w = (pw - 36) // 5
        for i, (name, col, level, cost, key) in enumerate(potion_data):
            bx = px + 10 + i * (slot_w + 4)
            by = py + 36
            # One slot at a time is highlighted, cycling every 50 frames.
            slot_col = 1 if self.shop_anim % 50 < 25 and i == (self.shop_anim // 50) % 5 else 0
            pyxel.rect(bx, by, slot_w, 130, slot_col)
            pyxel.rectb(bx, by, slot_w, 130, col)
            pyxel.text(bx + 4, by + 4, f"[{key}]", 15)

            if name == "AMMO":
                cx2 = bx + slot_w // 2
                cy2 = by + 48
                pyxel.rect(cx2 - 12, cy2 - 10, 24, 20, 4)
                pyxel.rectb(cx2 - 12, cy2 - 10, 24, 20, 9)
                pyxel.rect(cx2 - 12, cy2 - 2, 24, 4, 9)
                pyxel.rect(cx2 - 2, cy2 - 10, 4, 20, 9)
                if self.shop_anim % 20 < 10:
                    pyxel.pset(cx2 - 5, cy2 - 4, 15)
                    pyxel.pset(cx2 + 4, cy2 + 3, 15)
                pyxel.text(bx + 2, by + 62, "AMMO", col)
                pyxel.text(bx + 2, by + 70, f"{cost}c", 10)
                pyxel.text(bx + 2, by + 82, "REFILLS", 7)
                pyxel.text(bx + 2, by + 90, "ALL GUNS", 7)
                for bi in range(4):
                    pyxel.rect(bx + 4 + bi * 10, by + 104, 4, 8, col)
                    pyxel.pset(bx + 5 + bi * 10, by + 103, 15)
            else:
                draw_potion(bx + slot_w // 2 - 11, by + 16, col,
                            self.shop_anim + i * 15, name, level, f"{cost}c")
                dot_y = by + 94
                for d in range(5):
                    dot_col = col if d < level else 1
                    pyxel.circ(bx + 6 + d * 10, dot_y, 3, dot_col)
                    pyxel.circb(bx + 6 + d * 10, dot_y, 3, 7)
                pyxel.text(bx + 2, by + 108, descs[i], 7)

        pyxel.text(px + 12, py + ph - 16, "TAB/B CLOSE   1-5 BUY", 7)
        if self.shop_anim % 30 < 15:
            pyxel.rectb(px - 1, py - 1, pw + 2, ph + 2, 15)

    def draw_skin_select(self):
        """Draw the character-selection screen.

        Shows a static dot field, the game title with a one-pixel glow, and
        one 138x160 card per kitty (sprite, name, trait name and start
        bonuses). The card for ``self.selected_skin`` gets an accent border
        and a double highlight frame. Control hints are centred at the
        bottom using the file's ``len(text) * 2`` centring convention.
        """
        pyxel.cls(0)
        for i in range(40):
            pyxel.pset((i * 173 + 7) % WIDTH, (i * 97 + 13) % HEIGHT, [1, 2, 5][i % 3])

        title1 = "KITTY SLAYER"
        title2 = "CHOOSE YOUR HERO"
        title1_x = WIDTH // 2 - len(title1) * 2
        # Glow copies first, main colour last, so the +1 glow copy does not
        # paint over the main title.
        for dx in (-1, 1):
            pyxel.text(title1_x + dx, 14, title1, 13)
        pyxel.text(title1_x, 14, title1, 14)
        pyxel.text(WIDTH // 2 - len(title2) * 2, 26, title2, 15)

        skin_names = ["pink", "dark", "red"]
        start_x = 28
        for i in range(3):
            cx = start_x + i * 148
            cy = 48
            selected = i == self.selected_skin
            card_col   = 0 if selected else 1
            border_col = KITTY_ACCENT[i] if selected else 5
            pyxel.rect(cx, cy, 138, 160, card_col)
            pyxel.rectb(cx, cy, 138, 160, border_col)
            if selected:
                pyxel.rectb(cx - 1, cy - 1, 140, 162, 14)
                pyxel.rectb(cx - 2, cy - 2, 142, 164, 13)

            sname = skin_names[i]
            u, v, sw, sh, colkey = SPRITE_DEST[sname]
            try:
                pyxel.blt(cx + 69 - sw, cy + 22, 0, u, v, sw, sh, colkey, scale=2.0)
            except TypeError:
                # Presumably a pyxel version whose blt() has no ``scale``
                # keyword: draw the sprite unscaled instead.
                pyxel.blt(cx + 69 - sw // 2, cy + 38, 0, u, v, sw, sh, colkey)

            pyxel.text(cx + 42, cy + 82, KITTY_NAMES[i], 15 if selected else KITTY_ACCENT[i])
            pyxel.text(cx + 22, cy + 96, KITTY_TRAITS[i]["name"], KITTY_ACCENT[i])
            pyxel.text(cx + 14, cy + 112, "SPECIAL START BONUS", 7)

            trait = KITTY_TRAITS[i]
            bonus_lines = []
            if trait["jump"]:
                bonus_lines.append(("JUMP  +1", 11))
            if trait["speed"] > 0:
                bonus_lines.append(("SPEED +0.5", 12))
            if trait["strength"]:
                bonus_lines.append(("STR   +1", 8))
            # One 10-px row per bonus so a kitty with several bonuses stays
            # legible instead of drawing every line at the same y.
            for row, (label, col) in enumerate(bonus_lines):
                pyxel.text(cx + 22, cy + 126 + row * 10, label, col)

        hint_choose = "A / D  OR  ARROWS  TO  CHOOSE"
        hint_start = "SPACE  OR  ENTER  TO  START"
        pyxel.text(WIDTH // 2 - len(hint_choose) * 2, HEIGHT - 40, hint_choose, 7)
        pyxel.text(WIDTH // 2 - len(hint_start) * 2, HEIGHT - 26, hint_start, 15)

    def draw_dead(self):
        """Draw the death screen.

        When ``self.game_over_full`` is set, the run is over: show the final
        score (``total_score`` plus the current level's ``score``) and a
        restart prompt. Otherwise show the remaining lives and a respawn
        prompt. Each line is horizontally centred with the same
        ``len(text) * 2`` rule used elsewhere in this file; the previous
        fixed offsets left every line (and the variable-length score line in
        particular) off-centre.
        """
        pyxel.cls(0)
        for i in range(20):
            pyxel.pset((i * 113 + 7) % WIDTH, (i * 71 + 5) % HEIGHT, 2)

        def centered(dy, msg, col):
            # Draw msg centred horizontally, dy pixels from the screen middle.
            pyxel.text(WIDTH // 2 - len(msg) * 2, HEIGHT // 2 + dy, msg, col)

        if self.game_over_full:
            centered(-24, "GAME OVER", 8)
            centered(-12, "NO LIVES LEFT", 9)
            centered(4, f"FINAL SCORE: {self.total_score + self.score}", 7)
            centered(22, "PRESS SPACE TO START OVER", 15)
        else:
            centered(-20, "YOU DIED", 8)
            centered(0, f"LIVES LEFT: {self.lives_left}", 7)
            centered(20, "PRESS SPACE TO RESPAWN", 15)

    def draw_controls(self):
        """Draw the controls help screen.

        A drifting dot field sits behind a glowing "CONTROLS" title, a
        two-column list of actions and their keys (labels in the selected
        kitty's accent colour), and a blinking "PRESS SPACE TO START"
        prompt (visible 20 of every 30 frames).
        """
        pyxel.cls(0)
        t = pyxel.frame_count
        for i in range(30):
            pyxel.pset((i * 173 + t // 3) % WIDTH, (i * 97 + t // 4) % HEIGHT,
                       [1, 2, 5][i % 3])

        cx = WIDTH // 2
        title = "CONTROLS"
        title_x = cx - len(title) * 2
        # Glow copies first, main colour last, so the +1 glow copy does not
        # paint over the main title.
        for dx in (-1, 1):
            pyxel.text(title_x + dx, 12, title, 13)
        pyxel.text(title_x, 12, title, 14)

        accent = KITTY_ACCENT[self.selected_skin]
        rows = [
            ("MOVE",         "A / D  or  ARROW KEYS",   7),
            ("JUMP",         "W / UP / SPACE  (double-jump!)", 11),
            ("SHOOT",        "Z  (hold)",               10),
            ("GRAVITY GUN",  "X",                       13),
            ("SWITCH GUN",   "Q / E",                    9),
            ("OPEN SHOP",    "TAB / B",                 14),
        ]
        y = 30
        for label, value, col in rows:
            if label == "":
                # An empty label acts as a small vertical spacer.
                y += 4
                continue
            pyxel.text(cx - 90, y, label + ":", accent)
            pyxel.text(cx - 20, y, value, col)
            y += 12

        if t % 30 < 20:
            msg = "PRESS SPACE TO START"
            pyxel.rect(cx - len(msg) * 2 - 2, HEIGHT - 20, len(msg) * 4 + 4, 9, 0)
            pyxel.rectb(cx - len(msg) * 2 - 2, HEIGHT - 20, len(msg) * 4 + 4, 9, 14)
            pyxel.text(cx - len(msg) * 2, HEIGHT - 18, msg, 15)

    def draw_win(self):
        """Draw the victory screen.

        Layers, back to front: a blinking multicoloured sparkle field; 16
        pixels orbiting (WIDTH // 2, 90) on a flattened ellipse; a bobbing
        "YOU WIN!" banner made bold by smearing each glyph over a
        ``scale`` x ``scale`` grid (drop shadow, 4-way glow, fill, then a
        highlight copy); a flashing subtitle; a box with the final score and
        coin count; a decorative final portal; and a blinking play-again
        prompt.
        """
        frame = pyxel.frame_count
        pyxel.cls(1)

        sparkle_palette = (14, 13, 10, 15, 9, 11)
        for k in range(60):
            if (frame + k * 5) % 16 >= 8:
                continue  # this sparkle is in the "off" half of its blink
            pyxel.pset((k * 97 + frame * (1 + k % 4)) % WIDTH,
                       (k * 137 + frame // 2) % HEIGHT,
                       sparkle_palette[k % 6])

        orbit_cx, orbit_cy = WIDTH // 2, 90
        orbit_palette = (14, 15, 13, 10, 9)
        for k in range(16):
            theta = math.radians(frame * 1.2 + k * 22.5)
            radius = 30 + (frame // 3 + k * 12) % 55
            ox = orbit_cx + int(math.cos(theta) * radius)
            oy = orbit_cy + int(math.sin(theta) * radius * 0.55)
            if 0 <= ox < WIDTH and 0 <= oy < HEIGHT:
                pyxel.pset(ox, oy, orbit_palette[k % 5])

        banner = "YOU WIN!"
        scale = 3
        step = 4 * scale + 2
        banner_w = len(banner) * step - 2
        left = WIDTH // 2 - banner_w // 2
        top = 52 + int(math.sin(frame * 0.06) * 4)

        # Row-major (dx, dy) offsets used to fake a scaled-up glyph.
        smear = [(dx, dy) for dy in range(scale) for dx in range(scale)]
        glow_dirs = ((-1, 0), (1, 0), (0, -1), (0, 1))

        for idx, ch in enumerate(banner):
            if ch == " ":
                continue
            glyph_x = left + idx * step
            for dx, dy in smear:
                pyxel.text(glyph_x + dx + 3, top + dy + 4, ch, 2)
            for dx, dy in smear:
                for gx, gy in glow_dirs:
                    pyxel.text(glyph_x + dx + gx, top + dy + gy, ch, 13)
            for dx, dy in smear:
                pyxel.text(glyph_x + dx, top + dy, ch, 14)
            pyxel.text(glyph_x, top, ch, 15)

        subtitle = "KITTY SLAYER COMPLETE!"
        subtitle_col = 15 if (frame // 12) % 2 == 0 else 14
        pyxel.text(WIDTH // 2 - len(subtitle) * 2, top + scale * 8 + 6,
                   subtitle, subtitle_col)

        box_w = 140
        box_x = WIDTH // 2 - box_w // 2
        box_y = top + scale * 8 + 22
        pyxel.rect(box_x, box_y, box_w, 28, 0)
        pyxel.rectb(box_x, box_y, box_w, 28, 14)
        pyxel.text(box_x + 8, box_y + 4,  f"FINAL SCORE: {self.total_score}", 15)
        pyxel.text(box_x + 8, box_y + 14, f"COINS:       {self.coins_wallet}", 10)

        # A throwaway portal instance reused purely for its draw routine.
        trophy = LevelPortal(WIDTH // 2 - 14, box_y + 40, LEVEL_COUNT_FOR_WIN, is_final=True)
        trophy.anim = frame
        trophy.draw(0)

        if frame % 30 < 20:
            prompt = "PRESS SPACE TO PLAY AGAIN"
            pyxel.text(WIDTH // 2 - len(prompt) * 2, HEIGHT - 14, prompt, 7)

    def draw(self):
        """Per-frame render entry point.

        Full-screen states ("loading", "skin_select", "controls", "dead",
        "win") own the whole frame and return early. Otherwise the level is
        drawn back to front: background, terrain, portal, checkpoints,
        coins, gun pickups, gravity pickup, black holes, projectiles,
        enemies, the boss (while alive), cannons, particles and the player.
        The HUD, a flashing "BOSS INCOMING" banner while
        ``boss_intro_timer`` is positive, and the shop overlay (when open)
        go on top.
        """
        if self.state == "loading":
            draw_loading_screen(self.loading_timer)
            return

        full_screens = {
            "skin_select": self.draw_skin_select,
            "controls": self.draw_controls,
            "dead": self.draw_dead,
            "win": self.draw_win,
        }
        screen = full_screens.get(self.state)
        if screen is not None:
            screen()
            return

        def draw_each(items):
            # Draw every entity in items at the current camera offset.
            for item in items:
                item.draw(self.cam_x)

        self.draw_background()
        self.draw_solids()

        if self.level_portal:
            self.level_portal.draw(self.cam_x)
        draw_each(self.checkpoints)
        draw_each(self.coins)
        draw_each(self.gun_pickups)
        if self.gravity_pickup:
            self.gravity_pickup.draw(self.cam_x)
        draw_each(self.black_holes)
        draw_each(self.bullets)
        draw_each(self.rockets)
        draw_each(self.enemy_bullets)
        draw_each(self.boss_bullets)
        draw_each(self.enemies)
        if self.boss and self.boss.alive:
            self.boss.draw(self.cam_x)
        draw_each(self.cannons)
        draw_each(self.particles)
        self.player.draw(self.cam_x)

        self.draw_hud()

        # Banner is visible 3 of every 4 frames to give it a flicker.
        if self.boss_intro_timer > 0 and pyxel.frame_count % 4 < 3:
            banner = "!! BOSS INCOMING !!"
            box_w = len(banner) * 4 + 12
            box_x = WIDTH // 2 - box_w // 2
            mid_y = HEIGHT // 2
            pyxel.rect(box_x, mid_y - 16, box_w, 32, 0)
            pyxel.rectb(box_x, mid_y - 16, box_w, 32, 8)
            pyxel.rectb(box_x - 1, mid_y - 17, box_w + 2, 34, 9)
            pyxel.text(WIDTH // 2 - len(banner) * 2, mid_y - 4, banner, 8)

        if self.shop_open:
            self.draw_shop()


# Launch the game only when this file is executed as the main script, so
# importing the module (e.g. from tests or tooling) does not open a window.
if __name__ == "__main__":
    Game()