import pyxel
import random

# ---------------- CONFIG ----------------
# Tile size and screen size, in pixels.
TILE = 8
SCREEN_W, SCREEN_H = 160, 120

# Maze size, in tiles.
MAP_W, MAP_H = 40, 40

# Player state. Positions are in pixels and are overwritten by generate_map().
player_x, player_y = 16, 16
player_speed = 2             # pixels per frame
player_lives = 3
player_crystals = 0          # collected crystals are also spent as ammunition by shoot()
player_dir = (1, 0)          # facing / firing direction as a unit step

# Countdown timers, in frames (0 = inactive).
shield_timer = 0             # invulnerability after being hit
hit_timer = 0                # screen shake after being hit
level_transition_timer = 0   # "LEVEL n" screen
game_over_timer = 0          # "GAME OVER" screen

level = 1
exit_x = exit_y = 0

# Entity lists; the game functions mutate them in place and never rebind them.
ghosts = []
crystals = []
bullets = []

# Tile grid indexed map_data[y][x]: 1 = wall, 0 = floor (one independent list per row).
map_data = [[1] * MAP_W for _ in range(MAP_H)]

game_state = "title"

# Snow flakes as mutable [x, y] pairs, used only on the title screen.
snow = [[random.randint(0, SCREEN_W), random.randint(0, SCREEN_H)] for _ in range(40)]

# ---------------- MAP ----------------
def generate_map():
    """Carve a fresh random maze into map_data and put the player at its start.

    Every tile is first reset to wall (1). An iterative depth-first
    backtracker then walks from a random start cell in steps of two tiles.
    It opens (0) each unvisited cell it reaches, together with the wall tile
    between it and the cell it came from, and backtracks at dead ends.
    Cells are confined to 1..MAP_W-2 / 1..MAP_H-2, so the outermost ring of
    tiles always stays wall.

    Side effects: rewrites map_data in place and sets the globals
    player_x / player_y (pixels) to the start cell.
    """
    global player_x, player_y

    for y in range(MAP_H):
        map_data[y][:MAP_W] = [1] * MAP_W

    start_x = random.randint(1, MAP_W - 2)
    start_y = random.randint(1, MAP_H - 2)
    map_data[start_y][start_x] = 0
    player_x, player_y = start_x * TILE, start_y * TILE

    # One list object, reshuffled in place at every step of the walk.
    steps = [(2, 0), (-2, 0), (0, 2), (0, -2)]
    trail = [(start_x, start_y)]

    while trail:
        cur_x, cur_y = trail[-1]
        random.shuffle(steps)
        for step_x, step_y in steps:
            nxt_x, nxt_y = cur_x + step_x, cur_y + step_y
            in_bounds = 1 <= nxt_x < MAP_W - 1 and 1 <= nxt_y < MAP_H - 1
            if in_bounds and map_data[nxt_y][nxt_x] == 1:
                map_data[nxt_y][nxt_x] = 0
                # Open the wall tile between the current cell and the new one.
                map_data[cur_y + step_y // 2][cur_x + step_x // 2] = 0
                trail.append((nxt_x, nxt_y))
                break
        else:
            trail.pop()  # no unvisited neighbour: backtrack

# ---------------- HELPERS ----------------
def collide(x, y, size):
    """Return True if the square at (x, y) with side *size* touches a wall.

    Coordinates are pixels and may be floats. Only the four corners
    (x, y) .. (x + size, y + size) are tested. That is sufficient while size
    is smaller than TILE (callers here pass 6 and 2). A corner that lies
    outside the map counts as a collision.
    """
    corners = ((x, y), (x + size, y), (x, y + size), (x + size, y + size))
    for corner_x, corner_y in corners:
        col = int(corner_x // TILE)
        row = int(corner_y // TILE)
        outside = not (0 <= col < MAP_W and 0 <= row < MAP_H)
        if outside or map_data[row][col] == 1:
            return True
    return False

def random_free():
    """Return a random tile-aligned pixel position whose 6-px box is wall-free.

    Candidate tiles are drawn from 2..MAP_W-2 / 2..MAP_H-2 (inclusive) until
    one passes collide(). There is no retry limit, so the call only returns
    once such a tile exists and is drawn.
    """
    hi_x, hi_y = MAP_W - 2, MAP_H - 2
    while True:
        cand_x = random.randint(2, hi_x) * TILE
        cand_y = random.randint(2, hi_y) * TILE
        if collide(cand_x, cand_y, 6):
            continue
        return cand_x, cand_y

# ---------------- SPAWN ----------------
# Minimum per-axis distance (pixels) between the player's start position and
# the ghosts / exit placed by spawn(). Without it, a ghost could appear on the
# player and cost a life before they can react. The exit could also land on
# the player's tile, which makes check_exit() skip the level at once.
SPAWN_SAFE_DIST = 4 * TILE


def _random_free_away_from_player(min_dist=SPAWN_SAFE_DIST, tries=1000):
    """Return a random_free() position at least *min_dist* px from the player.

    Distance is the larger of |dx| and |dy|, matching the per-axis overlap
    tests the game uses for contacts. After *tries* rejected candidates, any
    free position is accepted so that a cramped map can never hang.
    """
    for _ in range(tries):
        x, y = random_free()
        if max(abs(x - player_x), abs(y - player_y)) >= min_dist:
            return x, y
    return random_free()


def spawn():
    """Repopulate the current map with ghosts, crystals and the exit.

    Clears ghosts, crystals and bullets, then places 15 + 5*level ghosts,
    60 + 15*level crystals and one exit on free tiles. Ghosts and the exit
    keep at least SPAWN_SAFE_DIST from the player's position, so call this
    after generate_map() has placed the player.
    """
    global exit_x, exit_y

    ghosts.clear()
    crystals.clear()
    bullets.clear()

    for _ in range(15 + level * 5):
        x, y = _random_free_away_from_player()
        ghosts.append({"x": x, "y": y, "alive": True, "hit_timer": 0})

    # Crystals are harmless, so they may appear anywhere, including near the player.
    for _ in range(60 + level * 15):
        x, y = random_free()
        crystals.append({"x": x, "y": y, "collected": False})

    exit_x, exit_y = _random_free_away_from_player()

# ---------------- PLAYER ----------------
def move_axis(x,y,dx,dy):
    """Move a 6-px box from (x, y) by (dx, dy) and return the new position.

    Despite the name, both axes are handled. The motion is split into
    int(max(|dx|, |dy|)) equal sub-steps. In each sub-step the x and y parts
    are tried separately against collide(), so a blocked axis does not stop
    the other one (the box slides along walls). If the larger component is
    below 1 px, the position is returned unchanged.
    """
    substeps = int(max(abs(dx), abs(dy)))
    if substeps == 0:
        return x, y

    step_x = dx / substeps
    step_y = dy / substeps
    for _ in range(substeps):
        if not collide(x + step_x, y, 6):
            x += step_x
        if not collide(x, y + step_y, 6):
            y += step_y
    return x, y

def update_player():
    """Read the arrow keys, move the player, and fire on a SPACE press.

    Horizontal and vertical input are independent. If both keys on one axis
    are held, RIGHT/DOWN win. player_dir (the firing direction) becomes the
    held key checked last in the order LEFT, RIGHT, UP, DOWN. It is left
    unchanged when no arrow key is held.
    """
    global player_x, player_y, player_dir

    bindings = (
        (pyxel.KEY_LEFT, (-1, 0)),
        (pyxel.KEY_RIGHT, (1, 0)),
        (pyxel.KEY_UP, (0, -1)),
        (pyxel.KEY_DOWN, (0, 1)),
    )

    dx = dy = 0
    for key, facing in bindings:
        if not pyxel.btn(key):
            continue
        player_dir = facing
        if facing[0]:
            dx = facing[0] * player_speed
        else:
            dy = facing[1] * player_speed

    player_x, player_y = move_axis(player_x, player_y, dx, dy)

    if pyxel.btnp(pyxel.KEY_SPACE):
        shoot()

# ---------------- SHOOT ----------------
def shoot():
    """Fire one bullet along player_dir, paid for with one crystal.

    Does nothing when the player has no crystals. The bullet starts 3 px
    right and down from the player's position, moves 3 px per frame along
    player_dir, and lives for 30 frames.
    """
    global player_crystals

    if player_crystals > 0:
        aim_x, aim_y = player_dir
        bullets.append({
            "x": player_x + 3,
            "y": player_y + 3,
            "dx": aim_x * 3,
            "dy": aim_y * 3,
            "life": 30,
            "trail": [],
        })
        player_crystals -= 1

# ---------------- GHOSTS ----------------
def update_ghosts():
    """Move every live ghost, apply contact damage, and tick death animations.

    Each frame, for every live ghost:
    * It picks a heading. With probability ``follow`` (0.5 + 0.1 per level,
      capped at 1) it heads straight for the player. Otherwise it takes one of
      8 random directions; diagonals are not normalised, so they cover more
      ground.
    * It moves ``speed`` px (0.8 + 0.15 per level) along that heading, but
      only if the destination's 6-px box is wall-free. Otherwise it stays put.
    * If it is then within 6 px of the player on both axes and the shield is
      down, the player loses a life, gets 120 frames of shield and 20 frames
      of screen shake. The 60-frame game-over countdown starts once no lives
      remain.

    Dead ghosts whose ``hit_timer`` is still positive only count it down.
    """
    global player_lives, shield_timer, game_over_timer, hit_timer

    speed = 0.8 + level * 0.15
    follow = min(1, 0.5 + level * 0.1)
    wander = [(1, 0), (-1, 0), (0, 1), (0, -1), (1, 1), (-1, -1), (1, -1), (-1, 1)]

    for ghost in ghosts:
        if ghost["alive"]:
            if random.random() < follow:
                to_x = player_x - ghost["x"]
                to_y = player_y - ghost["y"]
                dist = max(1, (to_x**2 + to_y**2)**0.5)
                head_x, head_y = to_x / dist, to_y / dist
            else:
                head_x, head_y = random.choice(wander)

            dest_x = ghost["x"] + head_x * speed
            dest_y = ghost["y"] + head_y * speed
            if not collide(dest_x, dest_y, 6):
                ghost["x"], ghost["y"] = dest_x, dest_y

            touching = abs(ghost["x"] - player_x) < 6 and abs(ghost["y"] - player_y) < 6
            if touching and shield_timer <= 0:
                player_lives -= 1
                shield_timer = 120
                hit_timer = 20
                if player_lives <= 0:
                    game_over_timer = 60
        elif ghost["hit_timer"] <= 0:
            continue  # fully dead: nothing left to update

        if ghost["hit_timer"] > 0:
            ghost["hit_timer"] -= 1

# ---------------- BULLETS ----------------
def update_bullets():
    """Advance every bullet one frame, resolve wall and ghost hits, drop spent bullets.

    A bullet is spent when its lifetime runs out, when its 2-px box touches a
    wall, or when it hits a ghost (within 6 px on both axes). A spent bullet
    cannot hit anything else, so one shot kills at most one ghost. Previously
    a bullet kept checking the remaining ghosts after it was consumed.
    """
    for b in bullets:
        b["x"] += b["dx"]
        b["y"] += b["dy"]
        b["life"] -= 1

        # Remember the last 5 positions for the trail effect.
        b["trail"].append((b["x"], b["y"]))
        if len(b["trail"]) > 5:
            b["trail"].pop(0)

        if b["life"] <= 0 or collide(b["x"], b["y"], 2):
            b["life"] = 0
            continue  # spent by age or on a wall: it can no longer hit a ghost

        for g in ghosts:
            if g["alive"] and abs(b["x"] - g["x"]) < 6 and abs(b["y"] - g["y"]) < 6:
                g["alive"] = False
                g["hit_timer"] = 20  # frames of death animation before the ghost vanishes
                b["life"] = 0
                break  # the bullet is consumed by the first ghost it hits

    bullets[:] = [b for b in bullets if b["life"] > 0]

# ---------------- ITEMS ----------------
def collect():
    """Pick up every uncollected crystal within 6 px of the player on both axes."""
    global player_crystals

    for crystal in crystals:
        if crystal["collected"]:
            continue
        near = abs(player_x - crystal["x"]) < 6 and abs(player_y - crystal["y"]) < 6
        if near:
            crystal["collected"] = True
            player_crystals += 1

# ---------------- EXIT ----------------
def check_exit():
    """Advance to the next level when the player is within 6 px of the exit.

    Increments level, starts the 60-frame level-transition screen, and
    builds the new level via next_level().
    """
    global level, level_transition_timer

    reached = abs(player_x - exit_x) < 6 and abs(player_y - exit_y) < 6
    if not reached:
        return

    level += 1
    level_transition_timer = 60
    next_level()

def next_level():
    """Build the level: carve a new maze, then repopulate it.

    The order matters. generate_map() also places the player, and spawn()
    picks free tiles on the freshly carved map.
    """
    generate_map()
    spawn()

# ---------------- DRAW ----------------
def _draw_title():
    """Draw the title screen: falling snow, a pulsing portal, the penguin and captions.

    Note: the snow flakes are also advanced here: 1 px down per drawn frame,
    with a random sideways jitter and a respawn at the top. Their speed is
    therefore tied to how often draw() runs.
    """
    pyxel.cls(0)

    for s in snow:
        pyxel.pset(s[0], s[1], 7)
        s[1] += 1
        s[0] += random.randint(-1, 1)
        if s[1] > SCREEN_H:
            s[1] = 0
            s[0] = random.randint(0, SCREEN_W)

    cx, cy = 80, 65

    # Portal: the two outer ring radii flicker with the frame counter.
    pulse = pyxel.frame_count % 30
    pyxel.circ(cx, cy, 18 + pulse % 3, 12)
    pyxel.circ(cx, cy, 12 + pulse % 2, 11)
    pyxel.circ(cx, cy, 6, 7)

    # Penguin, drawn on top of the portal.
    pyxel.circ(cx, cy + 10, 10, 0)   # body
    pyxel.circ(cx, cy + 12, 6, 7)    # belly
    pyxel.circ(cx, cy + 2, 5, 0)     # head
    pyxel.pset(cx - 2, cy + 1, 7)    # left eye
    pyxel.pset(cx + 2, cy + 1, 7)    # right eye
    pyxel.tri(cx - 1, cy + 3, cx + 1, cy + 3, cx, cy + 5, 9)          # beak
    pyxel.tri(cx - 4, cy + 18, cx - 1, cy + 18, cx - 3, cy + 20, 10)  # left foot
    pyxel.tri(cx + 1, cy + 18, cx + 4, cy + 18, cx + 3, cy + 20, 10)  # right foot

    pyxel.text(45, 20, "FROSTY PENGUIN CAVE", 7)
    pyxel.text(35, 105, "PRESS SPACE TO START", 10)


def _draw_game_over():
    """Draw the game-over screen."""
    pyxel.cls(0)
    pyxel.text(55, 55, "GAME OVER", 8)


def _draw_level_transition():
    """Draw the between-levels screen: a large penguin and the new level number."""
    pyxel.cls(12)

    cx, cy = 80, 60
    pyxel.circ(cx, cy, 20, 0)        # body
    pyxel.circ(cx, cy + 5, 12, 7)    # belly
    pyxel.circ(cx, cy - 18, 10, 0)   # head
    pyxel.pset(cx - 3, cy - 20, 7)   # left eye
    pyxel.pset(cx + 3, cy - 20, 7)   # right eye
    pyxel.pset(cx, cy - 18, 9)       # beak

    pyxel.text(60, 20, f"LEVEL {level}", 7)


def _draw_walls(camx, camy):
    """Draw the wall tiles that intersect the screen for integer camera offset (camx, camy).

    Only the tiles in (or one past) the visible window are visited, instead of
    all MAP_W x MAP_H tiles every frame. Tiles outside the window would have
    been clipped anyway, so the picture is unchanged.
    """
    first_x = max(0, camx // TILE)
    stop_x = min(MAP_W, (camx + SCREEN_W) // TILE + 1)   # exclusive
    first_y = max(0, camy // TILE)
    stop_y = min(MAP_H, (camy + SCREEN_H) // TILE + 1)   # exclusive

    for y in range(first_y, stop_y):
        row = map_data[y]
        py = y * TILE - camy
        for x in range(first_x, stop_x):
            if row[x] == 1:
                px = x * TILE - camx
                pyxel.rect(px, py, 8, 8, 12)             # tile body
                pyxel.rect(px, py, 8, 2, 7)              # 2-px band along the top
                pyxel.line(px, py + 7, px + 7, py, 6)    # diagonal accent line


def _draw_world():
    """Draw the in-game view: walls, crystals, exit, bullets, ghosts and the player.

    The camera centres on the player. While hit_timer is running it jitters
    by up to 2 px per axis (screen shake after a hit).
    """
    # Background colour 1 on level 1, colour 0 from level 2 on.
    pyxel.cls(max(0, 1 - level // 2))

    camx = int(player_x - SCREEN_W / 2)
    camy = int(player_y - SCREEN_H / 2)
    if hit_timer > 0:
        camx += random.randint(-2, 2)
        camy += random.randint(-2, 2)

    _draw_walls(camx, camy)

    # Crystals blink between colours 11 and 7 every 10 frames.
    crystal_col = 11 if pyxel.frame_count % 20 < 10 else 7
    for c in crystals:
        if not c["collected"]:
            x = c["x"] - camx
            y = c["y"] - camy
            pyxel.tri(x + 3, y, x, y + 6, x + 6, y + 6, crystal_col)

    # Exit: disc blinking between colours 9 and 10, with a colour-7 core.
    ex = exit_x - camx
    ey = exit_y - camy
    exit_col = 9 if pyxel.frame_count % 20 < 10 else 10
    pyxel.circ(ex + 4, ey + 4, 4, exit_col)
    pyxel.circ(ex + 4, ey + 4, 2, 7)

    # Bullets: trail pixels first, then the 2x2 head.
    for b in bullets:
        for tx, ty in b["trail"]:
            pyxel.pset(tx - camx, ty - camy, 7)
        pyxel.rect(b["x"] - camx, b["y"] - camy, 2, 2, 10)

    # Ghosts: colour 7 normally; colour 8 with a horizontal jitter while hit_timer runs.
    for g in ghosts:
        if not g["alive"] and g["hit_timer"] <= 0:
            continue  # fully dead: not drawn
        x = int(g["x"] - camx)
        y = int(g["y"] - camy)
        dying = g["hit_timer"] > 0
        shake = random.randint(-1, 1) if dying else 0
        col = 8 if dying else 7
        pyxel.circ(x + 3 + shake, y + 3, 3, col)
        pyxel.rect(x + shake, y + 3, 6, 3, col)
        pyxel.pset(x + 2 + shake, y + 3, 0)
        pyxel.pset(x + 4 + shake, y + 3, 0)

    # Player sprite.
    px = player_x - camx
    py = player_y - camy
    pyxel.circ(px + 3, py + 3, 3, 0)
    pyxel.circ(px + 3, py + 4, 2, 7)
    pyxel.pset(px + 2, py + 2, 7)
    pyxel.pset(px + 4, py + 2, 7)
    pyxel.pset(px + 3, py + 3, 9)


def _draw_hud():
    """Draw the HUD: one heart per remaining life, the crystal count and the level."""
    for i in range(player_lives):
        x = 5 + i * 10
        y = 5
        pyxel.circ(x + 2, y + 2, 2, 8)
        pyxel.circ(x + 5, y + 2, 2, 8)
        pyxel.tri(x, y + 3, x + 7, y + 3, x + 3, y + 7, 8)

    pyxel.text(5, 20, f"Crystals:{player_crystals}", 7)
    pyxel.text(5, 30, f"Level:{level}", 10)


def draw():
    """Pyxel draw callback: render the screen matching the current game phase.

    Priority order: title screen, game-over screen, level transition, and
    otherwise the game world with the HUD on top.
    """
    if game_state == "title":
        _draw_title()
    elif game_over_timer > 0:
        _draw_game_over()
    elif level_transition_timer > 0:
        _draw_level_transition()
    else:
        _draw_world()
        _draw_hud()

# ---------------- UPDATE ----------------
def update():
    """Pyxel update callback: advance whichever game phase is active.

    Phases, checked in priority order:
    * title screen: wait for SPACE, then switch to play (systems start next frame);
    * game over: count down, then call reset_game() when the countdown ends;
    * level transition: count down while the "LEVEL n" screen is shown;
    * play: run every game system, then tick the shield and screen-shake timers.
    """
    global shield_timer, level_transition_timer, game_over_timer, hit_timer, game_state

    if game_state == "title":
        if pyxel.btnp(pyxel.KEY_SPACE):
            game_state = "play"
        return

    if game_over_timer > 0:
        game_over_timer -= 1
        if not game_over_timer:
            reset_game()
        return

    if level_transition_timer > 0:
        level_transition_timer -= 1
        return

    for system in (update_player, update_ghosts, update_bullets, collect, check_exit):
        system()

    # These timers only count down during active play.
    if shield_timer > 0:
        shield_timer -= 1
    if hit_timer > 0:
        hit_timer -= 1

# ---------------- RESET ----------------
def reset_game():
    """Start a new run from level 1 and build a fresh level.

    Restores the per-run globals to their start-of-game values. This
    includes the shield, screen-shake and transition timers and the facing
    direction. The fatal hit sets shield_timer and hit_timer, and update()
    does not tick them during the game-over countdown. Without resetting
    them, the new run would begin with leftover invulnerability and screen
    shake.
    """
    global player_lives, player_crystals, level
    global shield_timer, hit_timer, level_transition_timer, player_dir

    player_lives = 3
    player_crystals = 0
    level = 1
    shield_timer = 0
    hit_timer = 0
    level_transition_timer = 0
    player_dir = (1, 0)

    generate_map()  # also places the player at the new maze's start cell
    spawn()

# ---------------- START ----------------
# Open the game window, build the first level, then hand control to
# Pyxel's main loop with our update/draw callbacks.
pyxel.init(SCREEN_W, SCREEN_H, title="Frosty Penguin Cave ❄️")
generate_map()  # carve the first maze (also places the player)
spawn()         # then populate it with ghosts, crystals and the exit
pyxel.run(update, draw)