import math
from dataclasses import dataclass, field
from typing import List, Optional

import pyxel


# =========================================================
# CONFIG
# =========================================================

# Screen resolution in pixels.
SCREEN_W, SCREEN_H = 220, 140

# Playfield limits in screen y: the floor surface and the ceiling surface.
# Screen y grows downward, so GROUND_Y > CEILING_Y.
GROUND_Y = 118
CEILING_Y = 20

# The player is a PLAYER_SIZE x PLAYER_SIZE square at a fixed screen x;
# the world scrolls past it.
PLAYER_SIZE = 10
PLAYER_X = 42

# Per-frame physics: gravity is added to vy every frame, jump values are the
# initial vy (negative = upward on screen under normal gravity).
GRAVITY, JUMP_POWER = 0.34, -5.2
BALL_GRAVITY, BALL_JUMP = 0.42, -4.6

# Vertical pixels the wave moves per frame.
WAVE_SPEED = 3.2

# Game-state identifiers.
STATE_MENU, STATE_PLAY, STATE_DEAD, STATE_WIN = "menu", "play", "dead", "win"

# Player movement modes (also used as portal targets).
MODE_CUBE, MODE_WAVE, MODE_BALL = "cube", "wave", "ball"


# =========================================================
# HELPERS
# =========================================================

def clamp(value, low, high):
    """Return *value* limited to the range [low, high].

    The upper bound is applied first, so when ``low > high`` the result is ``low``.
    """
    capped = min(high, value)
    return max(low, capped)


def rects_overlap(ax, ay, aw, ah, bx, by, bw, bh):
    """Return True if axis-aligned rectangles A and B intersect.

    Each rectangle is given as top-left (x, y) plus width and height. The
    comparisons are strict, so rectangles that merely touch along an edge do
    not count as overlapping.
    """
    overlap_x = ax < bx + bw and ax + aw > bx
    overlap_y = ay < by + bh and ay + ah > by
    return overlap_x and overlap_y


# =========================================================
# GAME OBJECTS
# =========================================================

@dataclass
class Hazard:
    """A deadly rectangle in world coordinates, drawn according to ``kind``.

    ``kind`` is one of "block", "spike_up", "spike_down", "spike_left" or
    "spike_right"; any other value is simply not drawn.
    """

    x: float
    y: float
    w: int
    h: int
    kind: str = "block"

    def draw(self, scroll_x: float):
        """Draw the hazard shifted left by *scroll_x*; skip it when off-screen."""
        w, h = self.w, self.h
        left = int(self.x - scroll_x)
        top = int(self.y)
        right = left + w
        bottom = top + h

        # Small margin so shapes don't pop in/out exactly at the screen edge.
        if right < -2 or left > SCREEN_W + 2:
            return

        kind = self.kind
        mid_x = left + w // 2
        if kind == "block":
            # Dark frame, two inset fills, then a bright outline on top.
            pyxel.rect(left, top, w, h, 1)
            pyxel.rect(left + 1, top + 1, max(0, w - 2), max(0, h - 2), 13)
            pyxel.rect(left + 2, top + 2, max(0, w - 4), max(0, h - 4), 5)
            pyxel.rectb(left, top, w, h, 7)
        elif kind == "spike_up":
            pyxel.tri(left, bottom, mid_x, top, right, bottom, 8)
            pyxel.tri(left + 1, bottom, mid_x, top + 2, right - 1, bottom, 10)
            pyxel.line(left + 2, bottom - 1, mid_x, top + 3, 7)
            pyxel.line(right - 2, bottom - 1, mid_x, top + 3, 7)
        elif kind == "spike_down":
            pyxel.tri(left, top, mid_x, bottom, right, top, 8)
            pyxel.tri(left + 1, top, mid_x, bottom - 2, right - 1, top, 10)
            pyxel.line(left + 2, top + 1, mid_x, bottom - 3, 7)
            pyxel.line(right - 2, top + 1, mid_x, bottom - 3, 7)
        elif kind == "spike_left":
            mid_y = top + h // 2
            pyxel.tri(right, top, left, mid_y, right, bottom, 8)
            pyxel.tri(right - 1, top + 1, left + 2, mid_y, right - 1, bottom - 1, 10)
        elif kind == "spike_right":
            mid_y = top + h // 2
            pyxel.tri(left, top, right, mid_y, left, bottom, 8)
            pyxel.tri(left + 1, top + 1, right - 2, mid_y, left + 1, bottom - 1, 10)


@dataclass
class Orb:
    """A tappable ring at a world position.

    ``orb_type`` is "jump", "gravity" or "dash"; any other value is drawn with a
    fallback palette and no glyph. Once ``used`` is set the orb is no longer
    drawn. ``anim`` is a frame counter that wraps every 60 frames.
    """

    x: float
    y: float
    orb_type: str
    used: bool = False
    anim: int = 0

    def update(self):
        """Advance the pulse animation by one frame, wrapping every 60 frames."""
        next_frame = self.anim + 1
        self.anim = next_frame % 60

    def draw(self, scroll_x: float):
        """Draw glow, core, outline and type glyph; skip used or off-screen orbs."""
        if self.used:
            return

        cx = int(self.x - scroll_x)
        cy = int(self.y)
        if not -12 <= cx <= SCREEN_W + 12:
            return

        # The glow ring grows by one pixel during the first half of the cycle.
        pulse = 1 if self.anim < 30 else 0
        palette = {"jump": (10, 9), "gravity": (12, 6), "dash": (11, 7)}
        main, glow = palette.get(self.orb_type, (14, 8))

        pyxel.circ(cx, cy, 5 + pulse, glow)
        pyxel.circ(cx, cy, 4, main)
        pyxel.circb(cx, cy, 5, 7)

        kind = self.orb_type
        if kind == "jump":
            # Upward arrow.
            pyxel.line(cx, cy + 2, cx, cy - 2, 7)
            pyxel.line(cx, cy - 2, cx - 2, cy, 7)
            pyxel.line(cx, cy - 2, cx + 2, cy, 7)
        elif kind == "gravity":
            # Two bars with two dots between them.
            pyxel.line(cx - 2, cy - 2, cx + 2, cy - 2, 7)
            pyxel.line(cx - 2, cy + 2, cx + 2, cy + 2, 7)
            pyxel.pset(cx - 1, cy, 7)
            pyxel.pset(cx + 1, cy, 7)
        elif kind == "dash":
            # Rightward arrow.
            pyxel.line(cx - 2, cy, cx + 2, cy, 7)
            pyxel.line(cx + 2, cy, cx, cy - 2, 7)
            pyxel.line(cx + 2, cy, cx, cy + 2, 7)


@dataclass
class Portal:
    """A mode-switch gate in world coordinates granting ``target_mode``.

    ``activated`` marks a portal that has already fired. ``anim`` is a frame
    counter that wraps every 40 frames.
    """

    x: float
    y: float
    w: int
    h: int
    target_mode: str
    activated: bool = False
    anim: int = 0

    def update(self):
        """Advance the pulse animation by one frame, wrapping every 40 frames."""
        next_frame = self.anim + 1
        self.anim = next_frame % 40

    def draw(self, scroll_x: float):
        """Draw the gate, its pulsing halo and a mode icon; skip it when off-screen."""
        left = int(self.x - scroll_x)
        top = int(self.y)
        w, h = self.w, self.h
        if left + w < -10 or left > SCREEN_W + 10:
            return

        pulse = 1 if self.anim < 20 else 0
        mode = self.target_mode
        if mode == MODE_WAVE:
            main, glow = 12, 5
        elif mode == MODE_BALL:
            main, glow = 9, 14
        else:
            main, glow = 10, 9

        mid_x = left + w // 2
        mid_y = top + h // 2
        pyxel.circb(mid_x, mid_y, 8 + pulse, glow)
        pyxel.rect(left, top, w, h, main)
        pyxel.rectb(left, top, w, h, 7)

        if mode == MODE_WAVE:
            # "^" zig-zag icon.
            pyxel.line(left + 3, top + h - 4, mid_x, top + 4, 7)
            pyxel.line(mid_x, top + 4, left + w - 3, top + h - 4, 7)
        elif mode == MODE_BALL:
            pyxel.circ(mid_x, mid_y, 4, 7)
            pyxel.circb(mid_x, mid_y, 4, glow)
        else:
            # Filled square icon for the cube (or any other) mode.
            pyxel.rect(left + 4, top + 4, w - 8, h - 8, 7)


@dataclass
class DecoBlock:
    """A purely decorative background rectangle drawn with parallax scrolling."""

    x: float
    y: float
    w: int
    h: int
    color: int

    def draw(self, scroll_x: float, parallax: float):
        """Draw the block shifted by ``scroll_x * parallax``; skip it when off-screen.

        A ``parallax`` below 1 makes the block move slower than the playfield.
        """
        left = int(self.x - scroll_x * parallax)
        top = int(self.y)
        visible = left + self.w >= 0 and left <= SCREEN_W
        if not visible:
            return

        pyxel.rect(left, top, self.w, self.h, self.color)
        pyxel.rectb(left, top, self.w, self.h, 1)


@dataclass
class Level:
    """One level: scroll parameters plus its hazards, orbs, portals and decoration.

    ``speed`` is the base scroll distance added per frame and ``length`` the
    scroll distance at which the level counts as completed.
    """

    name: str
    speed: float
    length: int
    hazards: List[Hazard] = field(default_factory=list)
    orbs: List[Orb] = field(default_factory=list)
    portals: List[Portal] = field(default_factory=list)
    deco: List[DecoBlock] = field(default_factory=list)
    difficulty: str = "Easy"
    bg_color: int = 0

    def clone(self):
        """Return an independent copy of this level, ready for a fresh run.

        Hazards and deco blocks are copied field by field. Orbs and portals are
        rebuilt from their layout fields only, so their runtime state (``used``,
        ``activated``, ``anim``) starts at the defaults.
        """
        fresh_hazards = [Hazard(hz.x, hz.y, hz.w, hz.h, hz.kind) for hz in self.hazards]
        fresh_orbs = [Orb(orb.x, orb.y, orb.orb_type) for orb in self.orbs]
        fresh_portals = [
            Portal(gate.x, gate.y, gate.w, gate.h, gate.target_mode) for gate in self.portals
        ]
        fresh_deco = [DecoBlock(db.x, db.y, db.w, db.h, db.color) for db in self.deco]

        return Level(
            name=self.name,
            speed=self.speed,
            length=self.length,
            hazards=fresh_hazards,
            orbs=fresh_orbs,
            portals=fresh_portals,
            deco=fresh_deco,
            difficulty=self.difficulty,
            bg_color=self.bg_color,
        )


# =========================================================
# PLAYER
# =========================================================

class Player:
    """The player avatar: position, velocity and per-mode movement physics.

    ``x`` is never changed by this class (the world scrolls instead); ``y`` is
    the top edge of the PLAYER_SIZE square. ``gravity_dir`` is +1 for normal
    gravity (falling toward GROUND_Y) and -1 when flipped toward CEILING_Y.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Put the player back on the floor as a living cube with default state."""
        self.mode = MODE_CUBE
        self.x = float(PLAYER_X)
        self.y = float(GROUND_Y - PLAYER_SIZE)
        self.vy = 0.0
        self.gravity_dir = 1
        self.on_surface = True
        self.alive = True
        self.dash_timer = 0
        self.rotation_frame = 0
        self.ball_angle = 0.0
        self.trail = []

    def jump(self):
        """Apply a jump impulse: the cube only from a surface, the ball at any time."""
        if not self.alive:
            return
        if self.mode == MODE_CUBE:
            if not self.on_surface:
                return
            impulse = JUMP_POWER
        elif self.mode == MODE_BALL:
            # The ball may jump at any time (bouncy mechanic).
            impulse = BALL_JUMP
        else:
            return
        self.vy = impulse * self.gravity_dir
        self.on_surface = False

    def flip_gravity(self):
        """Invert gravity (cube and ball only) and cancel vertical speed."""
        if self.alive and self.mode in (MODE_CUBE, MODE_BALL):
            self.gravity_dir = -self.gravity_dir
            self.vy = 0
            self.on_surface = False

    def dash(self):
        """Start a 9-frame dash; Game.get_speed reads dash_timer for the boost."""
        if not self.alive:
            return
        self.dash_timer = 9

    def set_mode(self, mode: str):
        """Switch movement mode and stop vertical motion.

        Cube and ball are pulled back inside the playfield; any other mode
        (the wave) is marked as not resting on a surface.
        """
        self.mode = mode
        self.vy = 0
        if mode in (MODE_CUBE, MODE_BALL):
            self.y = clamp(self.y, CEILING_Y, GROUND_Y - PLAYER_SIZE)
        else:
            self.on_surface = False

    def update(self, hold_input: bool, pressed: bool):
        """Advance one frame of physics for the current mode.

        Args:
            hold_input: whether the action button is held (steers the wave).
            pressed: not used by this method.
        """
        if not self.alive:
            return
        if self.dash_timer > 0:
            self.dash_timer -= 1

        mode = self.mode
        if mode == MODE_CUBE:
            self._step_cube()
        elif mode == MODE_BALL:
            self._step_ball()
        elif mode == MODE_WAVE:
            self._step_wave(hold_input)

        self.rotation_frame = (self.rotation_frame + 1) % 40
        self._record_trail()

    def _step_cube(self):
        """Cube: fall under gravity, land on the gravity-side surface, stop at the other."""
        self.vy += GRAVITY * self.gravity_dir
        self.y += self.vy
        if self.gravity_dir == 1:
            landed = self.y + PLAYER_SIZE >= GROUND_Y
            if landed:
                self.y = GROUND_Y - PLAYER_SIZE
                self.vy = 0
            self.on_surface = landed
            if self.y < CEILING_Y:
                self.y = CEILING_Y
                self.vy = 0
        else:
            landed = self.y <= CEILING_Y
            if landed:
                self.y = CEILING_Y
                self.vy = 0
            self.on_surface = landed
            if self.y + PLAYER_SIZE > GROUND_Y:
                self.y = GROUND_Y - PLAYER_SIZE
                self.vy = 0

    def _step_ball(self):
        """Ball: stronger, speed-capped gravity; bounces off the non-gravity surface."""
        self.vy += BALL_GRAVITY * self.gravity_dir
        self.vy = clamp(self.vy, -7.0, 7.0)
        self.y += self.vy
        self.ball_angle += 8.0 * self.gravity_dir
        if self.gravity_dir == 1:
            landed = self.y + PLAYER_SIZE >= GROUND_Y
            if landed:
                self.y = GROUND_Y - PLAYER_SIZE
                self.vy = 0
            self.on_surface = landed
            if self.y < CEILING_Y:
                # Bounce back down off the ceiling, keeping 60% of the speed.
                self.y = CEILING_Y
                self.vy = abs(self.vy) * 0.6
        else:
            landed = self.y <= CEILING_Y
            if landed:
                self.y = CEILING_Y
                self.vy = 0
            self.on_surface = landed
            if self.y + PLAYER_SIZE > GROUND_Y:
                # Bounce back up off the floor, keeping 60% of the speed.
                self.y = GROUND_Y - PLAYER_SIZE
                self.vy = -abs(self.vy) * 0.6

    def _step_wave(self, hold_input):
        """Wave: rise at constant speed while held, fall otherwise, stay in bounds."""
        step = -WAVE_SPEED if hold_input else WAVE_SPEED
        self.y = clamp(self.y + step, CEILING_Y, GROUND_Y - PLAYER_SIZE)

    def _record_trail(self):
        """Append the current centre to the trail, keeping at most 14 points."""
        half = PLAYER_SIZE // 2
        self.trail.append((int(self.x + half), int(self.y + half)))
        if len(self.trail) > 14:
            del self.trail[0]

    def get_hitbox(self):
        """Return the (x, y, w, h) collision box, slightly smaller than the sprite.

        The cube is inset by 1 px on each side; the ball and wave by 2 px.
        """
        inset = 1 if self.mode == MODE_CUBE else 2
        side = PLAYER_SIZE - 2 * inset
        return self.x + inset, self.y + inset, side, side


# =========================================================
# LEVELS
# =========================================================

def create_levels():
    """Build and return the list of built-in level templates.

    All hazard, orb, portal and deco x-positions are world coordinates;
    Game.update_play subtracts ``scroll_x`` to obtain screen x. Each
    ``Level(...)`` call passes name, speed, length, hazards, orbs, portals and
    deco positionally, in that order. The game plays on ``clone()`` copies of
    these templates (see Game.load_level), so the templates stay unmodified.
    """
    levels = []

    # -----------------------------------------------------
    # Level 1: Starter - cube only (no portals), very easy
    # -----------------------------------------------------
    hazards_1 = [
        Hazard(120, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(132, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(210, GROUND_Y - 18, 18, 18, "block"),
        Hazard(300, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(312, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(420, CEILING_Y, 10, 10, "spike_down"),
        Hazard(432, CEILING_Y, 10, 10, "spike_down"),
        Hazard(540, GROUND_Y - 18, 18, 18, "block"),
        Hazard(620, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(632, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(760, CEILING_Y, 10, 10, "spike_down"),
        Hazard(772, CEILING_Y, 10, 10, "spike_down"),
    ]
    orbs_1 = [
        Orb(260, 76, "jump"),
        Orb(490, 60, "gravity"),
        Orb(710, 50, "gravity"),
    ]
    deco_1 = [
        DecoBlock(80, 34, 16, 24, 1),
        DecoBlock(170, 26, 22, 30, 2),
        DecoBlock(340, 30, 18, 22, 1),
        DecoBlock(580, 24, 28, 32, 2),
        DecoBlock(820, 28, 18, 18, 1),
    ]
    levels.append(Level("Starter", 2.1, 980, hazards_1, orbs_1, [], deco_1, difficulty="Easy", bg_color=0))

    # -----------------------------------------------------
    # Level 2: Wave Mix - Cube + Wave
    # -----------------------------------------------------
    hazards_2 = [
        Hazard(100, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(112, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(124, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(220, GROUND_Y - 18, 18, 18, "block"),
        Hazard(280, GROUND_Y - 18, 18, 18, "block"),
        Hazard(340, CEILING_Y, 10, 10, "spike_down"),
        Hazard(352, CEILING_Y, 10, 10, "spike_down"),
        # Wave tunnel: paired floor/ceiling blocks leave a gap to fly through
        Hazard(640, GROUND_Y - 20, 22, 20, "block"),
        Hazard(640, CEILING_Y, 22, 20, "block"),
        Hazard(720, GROUND_Y - 22, 24, 22, "block"),
        Hazard(720, CEILING_Y, 24, 22, "block"),
        Hazard(800, GROUND_Y - 18, 22, 18, "block"),
        Hazard(800, CEILING_Y, 22, 18, "block"),
        Hazard(890, GROUND_Y - 24, 24, 24, "block"),
        Hazard(890, CEILING_Y, 24, 24, "block"),
        # After wave
        Hazard(1090, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(1102, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(1114, GROUND_Y - 10, 10, 10, "spike_up"),
    ]
    orbs_2 = [
        Orb(300, 72, "jump"),
        Orb(460, 58, "gravity"),
        Orb(540, 52, "gravity"),
        Orb(1030, 70, "jump"),
    ]
    # NOTE(review): every portal in this file spans y 48..72 only, and
    # update_play triggers a portal only when it overlaps the player's box. A
    # cube resting on the floor (y 108..118) does not overlap it, so the player
    # has to be airborne when passing. Confirm this is intended.
    portals_2 = [
        Portal(580, 48, 14, 24, MODE_WAVE),
        Portal(980, 48, 14, 24, MODE_CUBE),
    ]
    deco_2 = [
        DecoBlock(70, 28, 20, 20, 1),
        DecoBlock(180, 22, 24, 28, 2),
        DecoBlock(400, 26, 16, 18, 1),
        DecoBlock(680, 24, 20, 24, 2),
        DecoBlock(920, 30, 18, 16, 1),
    ]
    levels.append(Level("Wave Mix", 2.25, 1280, hazards_2, orbs_2, portals_2, deco_2, difficulty="Normal", bg_color=1))

    # -----------------------------------------------------
    # Level 3: Ball - introduction to ball mode
    # -----------------------------------------------------
    hazards_3 = [
        Hazard(100, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(112, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(200, GROUND_Y - 18, 18, 18, "block"),
        # Ball section obstacles - bouncy passages
        Hazard(480, CEILING_Y, 10, 14, "spike_down"),
        Hazard(492, CEILING_Y, 10, 14, "spike_down"),
        Hazard(560, GROUND_Y - 14, 10, 14, "spike_up"),
        Hazard(572, GROUND_Y - 14, 10, 14, "spike_up"),
        Hazard(640, CEILING_Y, 10, 16, "spike_down"),
        Hazard(652, CEILING_Y, 10, 16, "spike_down"),
        Hazard(720, GROUND_Y - 16, 10, 16, "spike_up"),
        Hazard(732, GROUND_Y - 16, 10, 16, "spike_up"),
        Hazard(800, CEILING_Y, 10, 14, "spike_down"),
        # Back to cube
        Hazard(1000, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(1012, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(1080, GROUND_Y - 18, 18, 18, "block"),
    ]
    orbs_3 = [
        Orb(260, 74, "jump"),
        Orb(870, 58, "jump"),
        Orb(940, 76, "jump"),
    ]
    portals_3 = [
        Portal(360, 48, 14, 24, MODE_BALL),
        Portal(930, 48, 14, 24, MODE_CUBE),
    ]
    deco_3 = [
        DecoBlock(60, 30, 18, 24, 2),
        DecoBlock(160, 24, 22, 28, 5),
        DecoBlock(440, 26, 16, 18, 2),
        DecoBlock(760, 28, 22, 22, 5),
        DecoBlock(1050, 24, 20, 26, 2),
    ]
    levels.append(Level("Ball", 2.2, 1200, hazards_3, orbs_3, portals_3, deco_3, difficulty="Normal", bg_color=2))

    # -----------------------------------------------------
    # Level 4: Gravity Run - flipped gravity + harder
    # -----------------------------------------------------
    hazards_4 = [
        Hazard(90, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(102, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(180, CEILING_Y, 10, 10, "spike_down"),
        Hazard(192, CEILING_Y, 10, 10, "spike_down"),
        Hazard(270, GROUND_Y - 18, 18, 18, "block"),
        Hazard(298, GROUND_Y - 18, 18, 18, "block"),
        Hazard(392, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(404, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(416, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(560, GROUND_Y - 22, 22, 22, "block"),
        Hazard(560, CEILING_Y, 22, 22, "block"),
        Hazard(650, GROUND_Y - 18, 18, 18, "block"),
        Hazard(650, CEILING_Y, 18, 18, "block"),
        Hazard(740, GROUND_Y - 22, 22, 22, "block"),
        Hazard(740, CEILING_Y, 22, 22, "block"),
        Hazard(930, CEILING_Y, 10, 10, "spike_down"),
        Hazard(942, CEILING_Y, 10, 10, "spike_down"),
        Hazard(1030, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(1042, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(1054, GROUND_Y - 10, 10, 10, "spike_up"),
    ]
    orbs_4 = [
        Orb(240, 74, "jump"),
        Orb(500, 58, "gravity"),
        Orb(860, 60, "gravity"),
        Orb(980, 76, "jump"),
    ]
    portals_4 = [
        Portal(520, 48, 14, 24, MODE_WAVE),
        Portal(820, 48, 14, 24, MODE_CUBE),
    ]
    deco_4 = [
        DecoBlock(60, 30, 18, 24, 1),
        DecoBlock(210, 26, 26, 28, 2),
        DecoBlock(470, 24, 18, 18, 1),
        DecoBlock(720, 28, 22, 22, 2),
        DecoBlock(1000, 24, 20, 26, 1),
    ]
    levels.append(Level("Gravity Run", 2.35, 1160, hazards_4, orbs_4, portals_4, deco_4, difficulty="Hard", bg_color=0))

    # -----------------------------------------------------
    # Level 5: Triple Mix - Cube + Ball + Wave, harder
    # -----------------------------------------------------
    hazards_5 = [
        Hazard(80, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(92, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(160, GROUND_Y - 18, 18, 18, "block"),
        Hazard(200, CEILING_Y, 10, 10, "spike_down"),
        Hazard(212, CEILING_Y, 10, 10, "spike_down"),
        Hazard(290, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(302, GROUND_Y - 10, 10, 10, "spike_up"),
        # Ball section
        Hazard(420, CEILING_Y, 10, 14, "spike_down"),
        Hazard(432, CEILING_Y, 10, 14, "spike_down"),
        Hazard(500, GROUND_Y - 14, 10, 14, "spike_up"),
        Hazard(512, GROUND_Y - 14, 10, 14, "spike_up"),
        Hazard(580, CEILING_Y, 10, 16, "spike_down"),
        Hazard(592, CEILING_Y, 10, 16, "spike_down"),
        Hazard(660, GROUND_Y - 12, 10, 12, "spike_up"),
        Hazard(672, GROUND_Y - 12, 10, 12, "spike_up"),
        # Wave tight tunnels
        Hazard(800, GROUND_Y - 22, 22, 22, "block"),
        Hazard(800, CEILING_Y, 22, 22, "block"),
        Hazard(880, GROUND_Y - 18, 18, 18, "block"),
        Hazard(880, CEILING_Y, 18, 18, "block"),
        Hazard(960, GROUND_Y - 24, 24, 24, "block"),
        Hazard(960, CEILING_Y, 24, 24, "block"),
        Hazard(1040, GROUND_Y - 20, 20, 20, "block"),
        Hazard(1040, CEILING_Y, 20, 20, "block"),
        # End cube
        Hazard(1200, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(1212, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(1224, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(1300, CEILING_Y, 10, 10, "spike_down"),
        Hazard(1312, CEILING_Y, 10, 10, "spike_down"),
        Hazard(1380, GROUND_Y - 18, 18, 18, "block"),
    ]
    orbs_5 = [
        Orb(240, 74, "jump"),
        Orb(360, 58, "gravity"),
        Orb(720, 60, "jump"),
        Orb(750, 42, "jump"),
        Orb(1150, 76, "gravity"),
        Orb(1260, 58, "jump"),
    ]
    portals_5 = [
        Portal(380, 48, 14, 24, MODE_BALL),
        Portal(740, 48, 14, 24, MODE_WAVE),
        Portal(1160, 48, 14, 24, MODE_CUBE),
    ]
    deco_5 = [
        DecoBlock(60, 28, 18, 22, 2),
        DecoBlock(140, 22, 22, 26, 5),
        DecoBlock(340, 26, 16, 18, 1),
        DecoBlock(700, 28, 20, 22, 2),
        DecoBlock(1120, 24, 18, 26, 5),
    ]
    levels.append(Level("Triple Mix", 2.5, 1480, hazards_5, orbs_5, portals_5, deco_5, difficulty="Hard", bg_color=1))

    # -----------------------------------------------------
    # Level 6: Nightmare - very hard, all modes, faster
    # -----------------------------------------------------
    hazards_6 = [
        Hazard(70, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(82, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(94, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(160, CEILING_Y, 10, 10, "spike_down"),
        Hazard(172, CEILING_Y, 10, 10, "spike_down"),
        Hazard(184, CEILING_Y, 10, 10, "spike_down"),
        Hazard(250, GROUND_Y - 18, 18, 18, "block"),
        Hazard(278, GROUND_Y - 18, 18, 18, "block"),
        # Tight ball section
        Hazard(400, CEILING_Y, 10, 18, "spike_down"),
        Hazard(412, CEILING_Y, 10, 18, "spike_down"),
        Hazard(424, CEILING_Y, 10, 18, "spike_down"),
        Hazard(460, GROUND_Y - 18, 10, 18, "spike_up"),
        Hazard(472, GROUND_Y - 18, 10, 18, "spike_up"),
        Hazard(520, CEILING_Y, 10, 16, "spike_down"),
        Hazard(532, CEILING_Y, 10, 16, "spike_down"),
        Hazard(560, GROUND_Y - 18, 10, 18, "spike_up"),
        Hazard(572, GROUND_Y - 18, 10, 18, "spike_up"),
        Hazard(620, CEILING_Y, 10, 18, "spike_down"),
        # Very tight wave tunnels
        Hazard(760, GROUND_Y - 16, 16, 16, "block"),
        Hazard(760, CEILING_Y, 16, 16, "block"),
        Hazard(820, GROUND_Y - 20, 20, 20, "block"),
        Hazard(820, CEILING_Y, 20, 20, "block"),
        Hazard(880, GROUND_Y - 16, 14, 16, "block"),
        Hazard(880, CEILING_Y, 14, 16, "block"),
        Hazard(940, GROUND_Y - 22, 22, 22, "block"),
        Hazard(940, CEILING_Y, 22, 22, "block"),
        Hazard(1000, GROUND_Y - 18, 16, 18, "block"),
        Hazard(1000, CEILING_Y, 16, 18, "block"),
        # Final cube gauntlet
        Hazard(1150, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(1162, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(1174, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(1220, CEILING_Y, 10, 10, "spike_down"),
        Hazard(1232, CEILING_Y, 10, 10, "spike_down"),
        Hazard(1244, CEILING_Y, 10, 10, "spike_down"),
        Hazard(1300, GROUND_Y - 18, 18, 18, "block"),
        Hazard(1340, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(1352, GROUND_Y - 10, 10, 10, "spike_up"),
        Hazard(1364, GROUND_Y - 10, 10, 10, "spike_up"),
    ]
    orbs_6 = [
        Orb(220, 74, "jump"),
        Orb(340, 60, "gravity"),
        Orb(660, 58, "jump"),
        Orb(690, 40, "jump"),
        Orb(1080, 60, "gravity"),
        Orb(1120, 80, "jump"),
        Orb(1280, 58, "gravity"),
    ]
    portals_6 = [
        Portal(360, 48, 14, 24, MODE_BALL),
        Portal(700, 48, 14, 24, MODE_WAVE),
        Portal(1100, 48, 14, 24, MODE_CUBE),
    ]
    deco_6 = [
        DecoBlock(50, 26, 16, 22, 2),
        DecoBlock(210, 22, 18, 24, 8),
        DecoBlock(330, 28, 14, 18, 2),
        DecoBlock(680, 24, 18, 20, 8),
        DecoBlock(1060, 26, 16, 22, 2),
    ]
    levels.append(Level("Nightmare", 2.7, 1480, hazards_6, orbs_6, portals_6, deco_6, difficulty="Extreme", bg_color=2))

    return levels


# =========================================================
# GAME
# =========================================================

class Game:
    def __init__(self):
        """Open the pyxel window, initialise all game state and start the main loop.

        ``pyxel.run`` is called last, so every attribute that ``update``/``draw``
        read must be set before it.
        """
        # ESC is the in-game "back to menu" key (update_play, update_dead,
        # update_win). pyxel.init's default quit_key is ESC, which would close the
        # whole application instead of reaching those handlers. KEY_NONE disables
        # pyxel's built-in quit shortcut.
        pyxel.init(
            SCREEN_W,
            SCREEN_H,
            title="Pixel Dash Wave",
            fps=60,
            quit_key=pyxel.KEY_NONE,
        )

        self.state = STATE_MENU
        self.player = Player()
        self.level_templates = create_levels()
        self.level_index = 0
        self.level: Optional[Level] = None
        self.scroll_x = 0.0
        self.score = 0
        # Best distance reached per level template, kept for this session only.
        self.best_scores = [0 for _ in self.level_templates]
        self.flash_timer = 0
        self.death_reason = ""
        self.menu_anim = 0
        # Reset in several places but not read in the code visible here — TODO confirm use.
        self.pause_requested = False

        pyxel.run(self.update, self.draw)

    # -----------------------------------------------------
    # flow
    # -----------------------------------------------------
    def load_level(self, index: int):
        """Begin a fresh run of level *index*, clamped into the valid range.

        The template is cloned so the run can mutate its orbs and portals freely.
        Does nothing when there are no level templates.
        """
        templates = self.level_templates
        if not templates:
            return

        chosen = clamp(index, 0, len(templates) - 1)
        self.level_index = chosen
        self.level = templates[chosen].clone()
        self.player.reset()

        # Clear per-run bookkeeping, then switch to gameplay.
        self.scroll_x, self.score = 0.0, 0
        self.flash_timer, self.death_reason = 0, ""
        self.pause_requested = False
        self.state = STATE_PLAY

    def back_to_menu(self):
        """Abandon any running level and return to the level-select menu."""
        self.state, self.level = STATE_MENU, None
        self.player.reset()
        # Wipe per-run bookkeeping so the menu starts clean.
        self.scroll_x, self.score = 0.0, 0
        self.flash_timer, self.death_reason = 0, ""
        self.pause_requested = False

    def kill_player(self, reason: str = ""):
        """End the current run and switch to the death screen.

        Marks the player dead, sets ``flash_timer`` to 5 frames, stores *reason*
        in ``death_reason`` and records the run's score as the level's best if
        it is higher.
        """
        self.player.alive = False
        self.state = STATE_DEAD
        self.flash_timer = 5
        self.death_reason = reason

        idx = self.level_index
        scores = self.best_scores
        if 0 <= idx < len(scores):
            scores[idx] = max(scores[idx], self.score)

    # -----------------------------------------------------
    # input
    # -----------------------------------------------------
    def input_pressed(self):
        """Return True on the frame any action input (SPACE, X, UP, left click) goes down."""
        action_inputs = (pyxel.KEY_SPACE, pyxel.KEY_X, pyxel.KEY_UP, pyxel.MOUSE_BUTTON_LEFT)
        return any(pyxel.btnp(code) for code in action_inputs)

    def input_held(self):
        """Return True while any action input (SPACE, X, UP, left click) is held down."""
        action_inputs = (pyxel.KEY_SPACE, pyxel.KEY_X, pyxel.KEY_UP, pyxel.MOUSE_BUTTON_LEFT)
        return any(pyxel.btn(code) for code in action_inputs)

    def get_speed(self):
        """Return this frame's scroll speed: the level's base speed, +1.6 while dashing.

        Returns 0.0 when no level is loaded.
        """
        level = self.level
        if level is None:
            return 0.0
        base = level.speed
        return base + 1.6 if self.player.dash_timer > 0 else base

    # -----------------------------------------------------
    # updates
    # -----------------------------------------------------
    def update_menu(self):
        """Handle the level-select menu for one frame.

        LEFT/A and RIGHT/D cycle through the levels (wrapping around),
        SPACE/RETURN starts the highlighted level, and ESC quits the application.
        """
        self.menu_anim = (self.menu_anim + 1) % 120

        # ESC means "back" everywhere else in the game, so on the top-level menu
        # it leaves the app. Quitting explicitly keeps a keyboard exit available
        # even when pyxel's built-in quit key is not ESC.
        if pyxel.btnp(pyxel.KEY_ESCAPE):
            pyxel.quit()
            return

        count = len(self.level_templates)
        if count == 0:
            return

        if pyxel.btnp(pyxel.KEY_RIGHT) or pyxel.btnp(pyxel.KEY_D):
            self.level_index = (self.level_index + 1) % count
        if pyxel.btnp(pyxel.KEY_LEFT) or pyxel.btnp(pyxel.KEY_A):
            self.level_index = (self.level_index - 1) % count
        if pyxel.btnp(pyxel.KEY_SPACE) or pyxel.btnp(pyxel.KEY_RETURN):
            self.load_level(self.level_index)

    def update_play(self):
        """Advance one frame of gameplay.

        Per frame, in order: ESC/R shortcuts, orb/portal animation, portal and
        orb triggers, jump/flip input, scrolling and player physics, the wave
        wall check, hazard collisions, and finally the goal check.
        """
        if self.level is None:
            self.back_to_menu()
            return

        # ESC pressed: go straight back to the menu; the game is NOT closed.
        # NOTE(review): pyxel.init uses ESC as its default quit_key; unless the
        # window was created with a different quit_key, pyxel may close the app
        # on ESC before this branch runs — confirm.
        if pyxel.btnp(pyxel.KEY_ESCAPE):
            self.back_to_menu()
            return

        # R restarts the current level from a fresh clone.
        if pyxel.btnp(pyxel.KEY_R):
            self.load_level(self.level_index)
            return

        pressed = self.input_pressed()
        held = self.input_held()

        for orb in self.level.orbs:
            orb.update()

        for portal in self.level.portals:
            portal.update()

        # Portals: each fires once, when its rectangle overlaps the player's full
        # sprite box (not the smaller hitbox).
        for portal in self.level.portals:
            if portal.activated:
                continue
            # World x -> screen x; the player's x is already a screen coordinate.
            sx = portal.x - self.scroll_x
            if rects_overlap(self.player.x, self.player.y, PLAYER_SIZE, PLAYER_SIZE, sx, portal.y, portal.w, portal.h):
                portal.activated = True
                self.player.set_mode(portal.target_mode)

        # Orbs: only react in cube/ball mode, only on the frame the button goes
        # down while overlapping a 10x10 box around the orb, and at most one orb
        # fires per frame (break).
        orb_triggered = False
        if self.player.mode in (MODE_CUBE, MODE_BALL):
            for orb in self.level.orbs:
                if orb.used:
                    continue
                sx = orb.x - self.scroll_x
                if rects_overlap(self.player.x, self.player.y, PLAYER_SIZE, PLAYER_SIZE, sx - 5, orb.y - 5, 10, 10):
                    if pressed:
                        orb.used = True
                        orb_triggered = True
                        if orb.orb_type == "jump":
                            # Orb jump is 1.38x a normal cube jump, in the current gravity direction.
                            self.player.vy = JUMP_POWER * 1.38 * self.player.gravity_dir
                            self.player.on_surface = False
                        elif orb.orb_type == "gravity":
                            self.player.flip_gravity()
                        elif orb.orb_type == "dash":
                            self.player.dash()
                        break

        # Jump input: a press already consumed by an orb is not reused here.
        if pressed and not orb_triggered:
            if self.player.mode == MODE_CUBE:
                self.player.jump()
            elif self.player.mode == MODE_BALL:
                self.player.flip_gravity()  # Ball flips gravity on press

        # The score is the distance scrolled so far.
        self.scroll_x += self.get_speed()
        self.score = int(self.scroll_x)
        self.player.update(held, pressed)

        # Wave dies on touching the floor or ceiling (Player.update clamps y to
        # those limits, so reaching a limit means contact).
        if self.player.mode == MODE_WAVE:
            if self.player.y <= CEILING_Y or self.player.y + PLAYER_SIZE >= GROUND_Y:
                self.kill_player("Wall hit")
                return

        # Collisions with hazards, using the player's reduced hitbox.
        px, py, pw, ph = self.player.get_hitbox()
        for hazard in self.level.hazards:
            sx = hazard.x - self.scroll_x
            if rects_overlap(px, py, pw, ph, sx, hazard.y, hazard.w, hazard.h):
                self.kill_player("Hazard")
                return

        # Goal: the level is won once the scroll distance reaches its length.
        if self.scroll_x >= self.level.length:
            self.state = STATE_WIN
            self.best_scores[self.level_index] = max(self.best_scores[self.level_index], self.score)

    def update_dead(self):
        """Death screen: count down the flash, then wait for restart or menu input.

        SPACE or R restarts the same level; ESC returns to the menu.
        """
        if self.flash_timer > 0:
            self.flash_timer -= 1

        wants_restart = pyxel.btnp(pyxel.KEY_SPACE) or pyxel.btnp(pyxel.KEY_R)
        if wants_restart:
            self.load_level(self.level_index)

        if pyxel.btnp(pyxel.KEY_ESCAPE):
            self.back_to_menu()

    def update_win(self):
        """Level-complete screen: SPACE/R replays the level, ESC returns to the menu."""
        wants_replay = pyxel.btnp(pyxel.KEY_SPACE) or pyxel.btnp(pyxel.KEY_R)
        if wants_replay:
            self.load_level(self.level_index)

        if pyxel.btnp(pyxel.KEY_ESCAPE):
            self.back_to_menu()
        # TODO: optionally auto-advance to the next level after a delay.

    def update(self):
        """Per-frame pyxel callback: run the handler for the current game state.

        Unknown states are ignored.
        """
        handler = {
            STATE_MENU: self.update_menu,
            STATE_PLAY: self.update_play,
            STATE_DEAD: self.update_dead,
            STATE_WIN: self.update_win,
        }.get(self.state)
        if handler is not None:
            handler()

    # -----------------------------------------------------
    # draw helpers
    # -----------------------------------------------------
    def draw_background(self):
        """Paint the sky, parallax stars and deco, then the striped ceiling and floor bands."""
        level = self.level
        pyxel.cls(level.bg_color if level else 0)

        # Two star layers scrolling at 20% and 35% of the camera speed.
        for col in range(0, SCREEN_W, 12):
            slow_x = int((col - self.scroll_x * 0.20) % SCREEN_W)
            fast_x = int((col * 2 - self.scroll_x * 0.35) % SCREEN_W)
            pyxel.pset(slow_x, 10 + col % 18, 1)
            pyxel.pset(fast_x, 22 + col % 14, 5)

        # The same deco blocks are drawn twice, at two parallax depths.
        if level is not None:
            for parallax in (0.35, 0.60):
                for deco in level.deco:
                    deco.draw(self.scroll_x, parallax)

        # Solid ceiling and floor bands with bright edge lines.
        pyxel.rect(0, 0, SCREEN_W, CEILING_Y, 2)
        pyxel.rect(0, GROUND_Y, SCREEN_W, SCREEN_H - GROUND_Y, 3)
        pyxel.line(0, CEILING_Y, SCREEN_W, CEILING_Y, 7)
        pyxel.line(0, GROUND_Y, SCREEN_W, GROUND_Y, 7)

        # Dashed stripes: two rows in the floor, one in the ceiling.
        for gx in range(0, SCREEN_W, 8):
            pyxel.line(gx, GROUND_Y + 2, gx + 4, GROUND_Y + 2, 11)
            pyxel.line(gx, GROUND_Y + 6, gx + 4, GROUND_Y + 6, 1)

        stripe_y = CEILING_Y - 4
        for cx in range(0, SCREEN_W, 10):
            pyxel.line(cx, stripe_y, cx + 3, stripe_y, 6)

    def draw_player(self):
        """Draw the player's trail, then the sprite for the current mode.

        MODE_CUBE and MODE_BALL have dedicated sprites; any other mode is drawn
        as the wave triangle.
        """
        self._draw_player_trail()

        x = int(self.player.x)
        y = int(self.player.y)

        if self.player.mode == MODE_CUBE:
            self._draw_cube_sprite(x, y)
        elif self.player.mode == MODE_BALL:
            self._draw_ball_sprite(x, y)
        else:
            self._draw_wave_sprite(x, y)

    def _draw_player_trail(self):
        """Draw one filled circle per ``player.trail`` point, tinted by mode."""
        # The colour depends only on the mode, so decide it once rather than
        # once per trail point.
        col = {MODE_BALL: 9, MODE_CUBE: 13}.get(self.player.mode, 12)
        for i, (tx, ty) in enumerate(self.player.trail):
            # The first five entries of the trail list are drawn smaller.
            r = 1 if i < 5 else 2
            pyxel.circ(tx, ty, r, col)

    def _draw_cube_sprite(self, x, y):
        """Draw the cube sprite with its top-left corner at (x, y).

        Colours change when ``gravity_dir`` is not 1. A white stroke cycles
        through four orientations, switching every 5 ``rotation_frame`` ticks.
        """
        size = PLAYER_SIZE
        body = 11 if self.player.gravity_dir == 1 else 12
        shade = 10 if self.player.gravity_dir == 1 else 6

        pyxel.rect(x, y, size, size, body)
        pyxel.rect(x + 1, y + 1, size - 2, size - 2, shade)
        pyxel.rectb(x, y, size, size, 7)

        # Stroke endpoints keep a 2 px inset from the sprite edge. They are
        # derived from PLAYER_SIZE so the art scales; at size 10 they are 2/7/5.
        lo = 2
        hi = size - 3
        mid = size // 2
        phase = (self.player.rotation_frame // 5) % 4
        if phase == 0:
            pyxel.line(x + lo, y + lo, x + hi, y + hi, 7)
        elif phase == 1:
            pyxel.line(x + hi, y + lo, x + lo, y + hi, 7)
        elif phase == 2:
            pyxel.line(x + lo, y + mid, x + hi, y + mid, 7)
        else:
            pyxel.line(x + mid, y + lo, x + mid, y + hi, 7)

        # Two dark pixels (the cube's "eyes"), mirrored across the sprite.
        pyxel.pset(x + 3, y + 3, 0)
        pyxel.pset(x + size - 4, y + 3, 0)

    def _draw_ball_sprite(self, x, y):
        """Draw the ball sprite inside the PLAYER_SIZE square at (x, y).

        A spoke from the centre points at ``player.ball_angle``, which is
        converted from degrees, to show rotation.
        """
        r = PLAYER_SIZE // 2
        cx = x + r
        cy = y + r

        body = 9 if self.player.gravity_dir == 1 else 14
        shade = 14 if self.player.gravity_dir == 1 else 9

        pyxel.circ(cx, cy, r, body)
        pyxel.circ(cx, cy, r - 2, shade)
        pyxel.circb(cx, cy, r, 7)

        ang = math.radians(self.player.ball_angle)
        lx = int(cx + math.cos(ang) * (r - 2))
        ly = int(cy + math.sin(ang) * (r - 2))
        pyxel.line(cx, cy, lx, ly, 7)
        pyxel.pset(cx, cy, 0)

    def _draw_wave_sprite(self, x, y):
        """Draw the wave sprite: an upward triangle filling the PLAYER_SIZE square."""
        s = PLAYER_SIZE
        mid = s // 2
        pyxel.tri(x, y + s, x + mid, y, x + s, y + s, 12)
        pyxel.tri(x + 1, y + s - 1, x + mid, y + 2, x + s - 1, y + s - 1, 6)
        pyxel.line(x + 2, y + s - 2, x + mid, y + 3, 7)
        pyxel.line(x + s - 2, y + s - 2, x + mid, y + 3, 7)

    def draw_progress(self):
        """Draw the level-progress bar and percentage along the bottom edge.

        Progress is ``scroll_x / level.length``, clamped to [0, 1]. The original
        code only clamped the upper end. The lower clamp keeps a negative scroll
        offset, should one ever occur, from producing a negative bar width or a
        negative percentage.
        """
        if self.level is None:
            return

        bar_x, bar_y, bar_w, bar_h = 28, SCREEN_H - 10, 164, 5
        # max(1, ...) guards against a zero-length level.
        progress = clamp(self.scroll_x / max(1, self.level.length), 0.0, 1.0)

        pyxel.rect(bar_x, bar_y, bar_w, bar_h, 1)
        pyxel.rect(bar_x, bar_y, int(bar_w * progress), bar_h, 11)
        pyxel.rectb(bar_x, bar_y, bar_w, bar_h, 7)

        # Percentage label just right of the bar (x = 196, y = SCREEN_H - 11).
        pct = int(progress * 100)
        pyxel.text(bar_x + bar_w + 4, bar_y - 1, f"{pct}%", 7)

    def draw_ui_play(self):
        """Draw the in-game HUD in the top-left corner.

        The HUD shows level name, score, best score, the current mode (tinted
        per mode) and two control hints, one line every 7 px.
        """
        if self.level is None:
            return

        player_mode = self.player.mode
        mode_color = {MODE_CUBE: 10, MODE_WAVE: 12, MODE_BALL: 9}.get(player_mode, 7)

        hud_lines = (
            (f"LEVEL: {self.level.name}", 7),
            (f"SCORE: {self.score}", 7),
            (f"BEST:  {self.best_scores[self.level_index]}", 7),
            (f"MODE:  {player_mode.upper()}", mode_color),
            ("SPACE=ACT  R=RESTART", 6),
            ("ESC=MENU (no quit!)", 6),
        )
        for row, (label, color) in enumerate(hud_lines):
            pyxel.text(4, 4 + 7 * row, label, color)

    def draw_difficulty_badge(self, diff, x, y):
        """Print the difficulty label *diff* at (x, y) in its tier colour.

        Labels not listed in the palette are drawn in colour 7.
        """
        palette = {"Easy": 11, "Normal": 10, "Hard": 9, "Extreme": 8}
        pyxel.text(x, y, diff, palette.get(diff, 7))

    def draw_center_box(self, title, subtitle, border_col):
        """Draw a black, bordered 170x40 message box, horizontally centred at y=44.

        The title is centred, assuming 4 px per character, and drawn in
        *border_col*. The subtitle and a fixed ESC hint are left-aligned below it.
        """
        box_w, box_h = 170, 40
        left = (SCREEN_W - box_w) // 2
        top = 44

        pyxel.rect(left, top, box_w, box_h, 0)
        pyxel.rectb(left, top, box_w, box_h, border_col)

        title_x = left + (box_w - 4 * len(title)) // 2
        pyxel.text(title_x, top + 8, title, border_col)
        pyxel.text(left + 6, top + 20, subtitle, 7)
        # German UI text: "(bleibt offen)" = "(stays open)".
        pyxel.text(left + 6, top + 29, "ESC = MENU  (bleibt offen)", 6)

    def draw_menu(self):
        """Draw the title / level-select screen.

        Still advances ``menu_anim`` (wrapping at 120) once per call, in case
        other code reads it. Three fixes:

        - The drifting background dots now use ``pyxel.frame_count``.
          Previously they used ``menu_anim``; their x positions are taken
          modulo SCREEN_W and the counter also feeds ``sin``, so every
          120-frame wrap made each dot teleport and its bob phase jump.
        - The pagination dots are now centred in the selector box. They used
          to sit 3 px left of centre.
        - The controls hint now sits just below the selector box. At y=98 it
          overlapped the mode labels drawn at y=101, because a text row is
          6 px tall.
        """
        pyxel.cls(0)
        self.menu_anim = (self.menu_anim + 1) % 120

        # Animated background: two dot layers drifting right and bobbing
        # vertically. The driver counter never wraps, so the motion is seamless.
        t = pyxel.frame_count
        for i in range(0, SCREEN_W, 14):
            y_off = int(math.sin((i + t) * 0.12) * 4)
            pyxel.pset((i * 3 + t) % SCREEN_W, 14 + (i % 20) + y_off, 1)
            pyxel.pset((i * 5 + t // 2) % SCREEN_W, 30 + (i % 17) + y_off, 5)

        # Title
        pyxel.text(66, 10, "PIXEL DASH WAVE", 11)
        pyxel.text(58, 19, "Cube + Ball + Wave", 7)

        # Level selector box
        box_x, box_y, box_w, box_h = 28, 36, 164, 56
        pyxel.rect(box_x, box_y, box_w, box_h, 1)
        pyxel.rectb(box_x, box_y, box_w, box_h, 7)

        if self.level_templates:
            level = self.level_templates[self.level_index]
            count = len(self.level_templates)

            # Arrows are dimmed (colour 1) when there is no level in that direction.
            pyxel.text(box_x + 4, box_y + 14, "<", 7 if self.level_index > 0 else 1)
            pyxel.text(box_x + box_w - 8, box_y + 14, ">", 7 if self.level_index < count - 1 else 1)

            # Level name centred (4 px per character).
            name_x = box_x + (box_w - len(level.name) * 4) // 2
            pyxel.text(name_x, box_y + 10, level.name, 10)

            # Difficulty badge, centred the same way.
            diff_x = box_x + (box_w - len(level.difficulty) * 4) // 2
            self.draw_difficulty_badge(level.difficulty, diff_x, box_y + 22)

            # Best score
            pyxel.text(box_x + 8, box_y + 34, f"Best: {self.best_scores[self.level_index]}", 7)

            # Pagination dots, 6 px apart, with the first and last dot centres
            # symmetric about the box's horizontal middle.
            first_dot_x = box_x + box_w // 2 - (count - 1) * 3
            for i in range(count):
                dc = 11 if i == self.level_index else 5
                pyxel.circ(first_dot_x + i * 6, box_y + 46, 2, dc)

        # Controls hint, 2 px below the selector box so it stays clear of the
        # mode labels further down.
        pyxel.text(22, box_y + box_h + 2, "LEFT/RIGHT = level  SPACE/ENTER = start", 6)

        # Mode icons preview, resting on the ground line.
        mx = 28
        my = GROUND_Y - 10
        pyxel.rect(mx, my, 10, 10, 11)          # cube
        pyxel.rectb(mx, my, 10, 10, 7)
        pyxel.circ(mx + 30, my + 5, 5, 9)       # ball
        pyxel.circb(mx + 30, my + 5, 5, 7)
        pyxel.tri(mx + 55, my + 10, mx + 60, my, mx + 65, my + 10, 12)  # wave

        pyxel.text(mx, my - 7, "CUBE", 10)
        pyxel.text(mx + 22, my - 7, "BALL", 9)
        pyxel.text(mx + 50, my - 7, "WAVE", 12)

        pyxel.rect(0, GROUND_Y, SCREEN_W, SCREEN_H - GROUND_Y, 3)
        pyxel.line(0, GROUND_Y, SCREEN_W, GROUND_Y, 7)

    # -----------------------------------------------------
    # main draw
    # -----------------------------------------------------
    def draw(self):
        """Per-frame render hook.

        In the menu, only the menu screen is drawn. Otherwise the play field is
        drawn: background, level objects, player, HUD and progress bar, then
        the game-over / level-clear box when applicable, and finally the white
        flash when ``flash_timer`` is positive.
        """
        if self.state == STATE_MENU:
            self.draw_menu()
            return

        self.draw_background()

        level = self.level
        if level is not None:
            # Group order defines layering: hazards, then orbs, then portals on top.
            for group in (level.hazards, level.orbs, level.portals):
                for obj in group:
                    obj.draw(self.scroll_x)

        self.draw_player()
        self.draw_ui_play()
        self.draw_progress()

        # German subtitles: "Neustart" = restart, "Nochmal" = again.
        overlays = {
            STATE_DEAD: ("GAME OVER", "SPACE/R = Neustart", 8),
            STATE_WIN: ("LEVEL CLEAR!", "SPACE/R = Nochmal", 11),
        }
        overlay = overlays.get(self.state)
        if overlay is not None:
            self.draw_center_box(*overlay)

        if self.flash_timer > 0:
            pyxel.rect(0, 0, SCREEN_W, SCREEN_H, 7)


# Entry point: constructing Game presumably sets up pyxel and enters its main
# loop (Game.__init__ is outside this view — confirm). This statement also runs
# when the module is imported; NOTE(review): consider an
# `if __name__ == "__main__":` guard if the module is ever imported rather
# than executed directly.
Game()