import math
import random
import pyxel
import json
import os

# ============================================================
# SUPREME LEADER SURVIVAL
# PART 1 / 2
# ============================================================

# Logical screen size in pixels (Terrain.draw walks range(SCREEN_W) columns).
SCREEN_W, SCREEN_H = 320, 200
FPS = 60  # presumably handed to pyxel.init in part 2 — confirm

# Flat-ground baseline used by Terrain.base_ground_y_at.
GROUND_Y = 168
# Player hitbox size; consumers live outside this chunk.
PLAYER_W, PLAYER_H = 16, 24

# Vertical physics tuning (per-frame units assumed; used outside this chunk).
GRAVITY = 0.28
MAX_FALL = 7.2

# Campaign layout; consumers live outside this chunk.
LEVEL_DISTANCE = 2400
CAMPAIGN_LEVELS = 10

# Progress file read by load_save() and written by save_save().
SAVE_FILE = "save.json"

# ------------------------------------------------------------
# COLORS
# ------------------------------------------------------------
# Palette indices 0..15, passed as the color argument of every pyxel draw call.
(
    C0, C1, C2, C3, C4, C5, C6, C7,
    C8, C9, C10, C11, C12, C13, C14, C15,
) = range(16)

# ------------------------------------------------------------
# HELPERS
# ------------------------------------------------------------
def clamp(v, lo, hi):
    """Return *v* limited to the closed interval [lo, hi]."""
    capped = min(hi, v)
    return max(lo, capped)


def rects_overlap(ax, ay, aw, ah, bx, by, bw, bh):
    """Return True if rectangles A and B (top-left corner + size) intersect.

    Edges that merely touch do not count as overlapping.
    """
    horizontal = ax < bx + bw and ax + aw > bx
    if not horizontal:
        return False
    return ay < by + bh and ay + ah > by


def dist2(ax, ay, bx, by):
    """Return the squared Euclidean distance between (ax, ay) and (bx, by)."""
    return sum(delta * delta for delta in (ax - bx, ay - by))


def sign(v):
    """Return -1, 0 or 1 according to the sign of *v* (0 for NaN)."""
    return (v > 0) - (v < 0)


# ------------------------------------------------------------
# COUNTRIES
# ------------------------------------------------------------
# Selectable countries in display order: internal key -> on-screen name.
COUNTRIES = [
    {"key": key, "name": name}
    for key, name in (
        ("usa", "USA"),
        ("mexico", "MEXICO"),
        ("china", "CHINA"),
        ("germany", "GERMANY"),
        ("serbia", "SERBIA"),
        ("iran", "IRAN"),
        ("israel", "ISRAEL"),
        ("antarctica", "ANTARCTICA"),
        ("italy", "ITALY"),
        ("russia", "RUSSIA"),
        ("poland", "POLAND"),
        ("france", "FRANCE"),
        ("nigeria", "NIGERIA"),
        ("switzerland", "SWITZERLAND"),
        ("canada", "CANADA"),
        ("ukraine", "UKRAINE"),
        ("turkiye", "TURKEY"),
        ("south_korea", "SOUTH KOREA"),
        ("north_korea", "NORTH KOREA"),
        ("saudi", "SAUDI ARABIA"),
        ("belgium", "BELGIUM"),
        ("austria", "AUSTRIA"),
        ("romania", "ROMANIA"),
        ("japan", "JAPAN"),
        ("uk", "UK"),
        ("spain", "SPAIN"),
        ("brazil", "BRASIL"),
        ("australia", "AUSTRALIA"),
        ("south_africa", "SOUTH AFRICA"),
    )
]

# Price per skin index; index 0 costs nothing (load_save always unlocks skin 0).
SKIN_PRICES = [0, 40, 70, 115, 180, 270, 390, 550]

# ------------------------------------------------------------
# More detailed stylized country outlines
# ------------------------------------------------------------
# Country key -> polygon vertices as (x, y) offsets in unscaled units.
# draw_country_outline() multiplies each vertex by `scale`, adds the (x, y)
# origin, and closes the polygon by joining the last vertex back to the first.
# NOTE(review): "india" has an outline here (and a draw_flag branch) but no
# entry in COUNTRIES, so it is unreachable from that list — confirm intended.
COUNTRY_OUTLINES = {
    "usa": [(0,8),(3,6),(7,5),(11,5),(15,4),(20,5),(24,4),(29,5),(34,6),(38,8),(35,11),(28,11),(23,12),(18,11),(12,12),(7,11),(3,10)],
    "mexico": [(4,5),(8,4),(12,5),(16,7),(19,10),(20,14),(17,18),(14,17),(12,21),(10,20),(9,16),(6,14),(4,10)],
    "china": [(3,6),(8,3),(13,4),(18,3),(24,4),(29,6),(33,9),(31,13),(25,14),(21,13),(16,15),(10,14),(5,11)],
    "germany": [(8,3),(12,2),(16,4),(18,8),(17,13),(14,17),(10,18),(7,14),(6,9)],
    "serbia": [(8,4),(12,4),(15,7),(14,11),(12,15),(8,16),(5,12),(5,8)],
    "iran": [(4,7),(9,4),(15,4),(21,4),(27,6),(31,9),(29,13),(23,15),(15,15),(9,14),(5,11)],
    "israel": [(9,3),(13,3),(15,6),(15,12),(13,16),(10,16),(8,12),(8,6)],
    "antarctica": [(4,18),(10,13),(16,11),(23,10),(29,11),(36,13),(40,17),(35,20),(25,21),(14,21),(7,20)],
    "india": [(13,3),(17,4),(20,7),(22,11),(20,15),(17,18),(14,22),(11,18),(9,13),(8,8)],
    "italy": [(12,3),(16,3),(17,7),(15,11),(18,15),(16,19),(13,22),(11,18),(12,13),(9,9)],
    "russia": [(2,7),(9,4),(17,4),(24,4),(31,5),(39,6),(44,8),(42,12),(36,13),(29,12),(23,14),(16,14),(9,12),(4,11)],
    "poland": [(7,4),(13,4),(17,7),(17,11),(13,15),(8,15),(5,10)],
    "france": [(7,4),(12,3),(16,5),(17,9),(15,13),(11,16),(7,14),(4,10),(5,6)],
    "nigeria": [(7,4),(13,4),(15,8),(15,13),(11,16),(7,15),(5,10)],
    "switzerland": [(8,4),(13,4),(16,8),(15,13),(10,15),(6,12),(5,8)],
    "canada": [(5,6),(8,3),(12,5),(16,3),(20,5),(24,3),(29,5),(33,9),(31,14),(27,18),(20,19),(13,18),(8,15)],
    "ukraine": [(5,6),(11,5),(17,5),(23,6),(27,9),(24,13),(18,14),(10,14),(5,11)],
    "turkiye": [(5,6),(11,5),(16,5),(22,6),(26,9),(25,12),(20,14),(12,14),(7,12),(4,9)],
    "south_korea": [(8,4),(13,4),(17,7),(17,11),(14,15),(9,15),(5,11),(5,7)],
    "north_korea": [(6,5),(12,5),(17,6),(18,10),(16,14),(11,15),(6,13),(4,9)],
    "saudi": [(5,7),(12,5),(19,5),(27,6),(31,9),(31,13),(25,16),(17,16),(10,15),(6,12)],
    "belgium": [(7,4),(12,4),(15,6),(15,11),(12,15),(8,15),(5,11),(5,7)],
    "austria": [(5,7),(11,5),(18,5),(25,6),(29,8),(27,12),(21,14),(13,15),(6,12)],
    "romania": [(6,5),(11,4),(16,5),(18,9),(17,13),(12,16),(7,15),(5,11)],
    "japan": [(12,3),(15,5),(16,9),(15,13),(12,16),(9,13),(8,8),(9,5)],
    "uk": [(7,4),(11,3),(15,5),(16,10),(14,14),(10,17),(6,15),(4,10),(5,6)],
    "spain": [(4,7),(10,5),(17,5),(23,6),(27,9),(24,13),(17,15),(9,14),(4,11)],
    "brazil": [(7,4),(13,4),(19,6),(22,10),(21,15),(17,19),(10,18),(6,13),(5,8)],
    "australia": [(6,8),(12,5),(19,5),(25,7),(29,11),(27,16),(21,18),(14,18),(8,15),(5,11)],
    "south_africa": [(5,7),(10,5),(16,5),(21,7),(24,11),(21,15),(14,16),(8,14),(4,10)],
}


def draw_country_outline(key, x, y, scale=3, color=C7, fill=False):
    """Draw the stylized outline of country *key* with its origin at (x, y).

    Unknown keys (or empty vertex lists) draw nothing. With *fill* set, the
    polygon is first painted as a triangle fan around the integer centroid of
    its scaled vertices; the closed outline is drawn on top in either case.
    """
    outline = COUNTRY_OUTLINES.get(key)
    if not outline:
        return

    verts = [(x + vx * scale, y + vy * scale) for vx, vy in outline]
    # Pair each vertex with its successor, wrapping the last back to the first.
    edges = list(zip(verts, verts[1:] + verts[:1]))

    if fill:
        count = len(verts)
        mid_x = sum(vx for vx, _ in verts) // count
        mid_y = sum(vy for _, vy in verts) // count
        for (ax, ay), (bx, by) in edges:
            pyxel.tri(mid_x, mid_y, ax, ay, bx, by, color)

    for (ax, ay), (bx, by) in edges:
        pyxel.line(ax, ay, bx, by, color)


# ------------------------------------------------------------
# SAVE
# ------------------------------------------------------------
def load_save(path=None):
    """Load persistent progress, falling back to defaults for anything unusable.

    Args:
        path: save-file location; defaults to the module-level ``SAVE_FILE``.

    Returns:
        dict with keys ``best_survival``, ``coins_total``, ``unlocked_skins``,
        ``selected_country`` and ``extra_hp_bonus``. Skin 0 is always present
        in ``unlocked_skins``.

    A missing, unreadable or malformed file yields the defaults. Each field is
    validated on its own, so one corrupt value keeps its default instead of
    aborting the load half-way (which previously left earlier fields applied
    and later ones silently dropped).
    """
    if path is None:
        path = SAVE_FILE

    data = {
        "best_survival": 0,
        "coins_total": 0,
        "unlocked_skins": [0],
        "selected_country": 0,
        "extra_hp_bonus": 0,
    }

    try:
        with open(path, "r", encoding="utf-8") as f:
            raw = json.load(f)
    except (OSError, ValueError):
        # Missing/unreadable file, bad UTF-8 or invalid JSON: start fresh.
        return data
    if not isinstance(raw, dict):
        return data

    for key in ("best_survival", "coins_total", "selected_country", "extra_hp_bonus"):
        try:
            data[key] = int(raw.get(key, data[key]))
        except (TypeError, ValueError, OverflowError):
            pass  # keep the default for this field only

    try:
        # Coerce ids to int so later membership tests (e.g. `1 in skins`) work.
        skins = [int(s) for s in raw.get("unlocked_skins", [0])]
    except (TypeError, ValueError, OverflowError):
        skins = [0]
    if 0 not in skins:
        skins.append(0)
    data["unlocked_skins"] = skins
    return data


def save_save(data, path=None):
    """Persist *data* as JSON, best effort.

    Args:
        data: JSON-serializable progress dict (as returned by ``load_save``).
        path: save-file location; defaults to the module-level ``SAVE_FILE``.

    The JSON is written to a sibling ``.tmp`` file and then moved over the
    real save with ``os.replace`` (atomic on POSIX). The previous version
    opened the target with mode "w" first, which truncated it before
    serialization could fail, so a non-serializable value or a crash left a
    corrupt save. Failures are swallowed and the previous save is kept.
    """
    if path is None:
        path = SAVE_FILE
    tmp_path = os.fspath(path) + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f)
        os.replace(tmp_path, path)
    except (OSError, TypeError, ValueError):
        # Best effort: keep the old save and clean up any partial temp file.
        try:
            os.remove(tmp_path)
        except OSError:
            pass


# ------------------------------------------------------------
# AUDIO
# ------------------------------------------------------------
def setup_audio():
    """Register sound slots 0-8 with pyxel, best effort.

    Any exception from pyxel stops registration and is swallowed, so callers
    never see an audio failure.
    """
    bank = (
        ("c3g3c4", "p", "764", "n", 18),                # 0: shot
        ("c2a1f1", "n", "7642", "f", 20),               # 1: hit
        ("f2c2", "n", "77", "f", 24),                   # 2: boom
        ("c3e3g3", "s", "753", "n", 14),                # 3: pickup
        ("c4d4", "p", "75", "n", 20),                   # 4: switch
        ("g2a2", "p", "76", "n", 14),                   # 5: rocket
        ("c4", "s", "7", "n", 8),                       # 6: warn
        ("c2g1", "n", "77", "n", 28),                   # 7: boss
        ("c3e3g3c4 g3e3c3g2", "s", "6543", "vvvv", 10),  # 8: story theme
    )
    try:
        for slot, (notes, tones, volumes, effects, speed) in enumerate(bank):
            pyxel.sound(slot).set(notes, tones, volumes, effects, speed)
    except Exception:
        pass


def sfx(ch, snd):
    """Play sound slot *snd* on channel *ch*; any pyxel error is ignored."""
    try:
        pyxel.play(ch, snd)
    except Exception:
        return


def play_story_music():
    """Loop sound slot 8 (story theme) on channel 3; any pyxel error is ignored."""
    try:
        pyxel.play(3, 8, loop=True)
    except Exception:
        return


def stop_story_music():
    """Stop whatever is playing on channel 3; any pyxel error is ignored."""
    try:
        pyxel.stop(3)
    except Exception:
        return


# ------------------------------------------------------------
# FLAGS
# ------------------------------------------------------------
def draw_flag(key, x, y, w=16, h=10, wave=0):
    """Draw a simplified flag for country *key* with its top-left at (x, y).

    Args:
        key: country key; unknown keys get a plain C7 rectangle.
        x, y: top-left corner in screen pixels.
        w, h: flag size in pixels.
        wave: phase value; the whole flag is shifted horizontally by
            ``int(sin(wave))``, i.e. -1, 0 or +1 pixels.

    The C7 frame is drawn last. It used to be drawn first, and every design
    below fills the full w x h area, so the frame was always painted over and
    never visible.
    """
    ox = int(math.sin(wave) * 1.0)

    def rx(v):
        # Horizontal position inside the flag, including the wave offset.
        return x + v + ox

    if key == "usa":
        for i in range(h):
            pyxel.line(rx(0), y + i, rx(w - 1), y + i, C8 if i % 2 == 0 else C7)
        pyxel.rect(rx(0), y, w // 2, h // 2 + 1, C1)

    elif key == "mexico":
        pyxel.rect(rx(0), y, w // 3, h, C3)
        pyxel.rect(rx(w // 3), y, w // 3, h, C7)
        pyxel.rect(rx(2 * w // 3), y, w - 2 * (w // 3), h, C8)
        pyxel.pset(rx(w // 2), y + h // 2, C4)

    elif key == "china":
        pyxel.rect(rx(0), y, w, h, C8)
        pyxel.pset(rx(3), y + 2, C10)
        pyxel.pset(rx(2), y + 3, C10)
        pyxel.pset(rx(3), y + 3, C10)
        pyxel.pset(rx(4), y + 3, C10)
        pyxel.pset(rx(3), y + 4, C10)

    elif key == "germany":
        pyxel.rect(rx(0), y, w, h // 3, C0)
        pyxel.rect(rx(0), y + h // 3, w, h // 3, C8)
        pyxel.rect(rx(0), y + 2 * (h // 3), w, h - 2 * (h // 3), C10)

    elif key == "serbia":
        pyxel.rect(rx(0), y, w, h // 3, C8)
        pyxel.rect(rx(0), y + h // 3, w, h // 3, C1)
        pyxel.rect(rx(0), y + 2 * (h // 3), w, h - 2 * (h // 3), C7)
        pyxel.pset(rx(4), y + h // 2, C10)

    elif key == "iran":
        pyxel.rect(rx(0), y, w, h // 3, C3)
        pyxel.rect(rx(0), y + h // 3, w, h // 3, C7)
        pyxel.rect(rx(0), y + 2 * (h // 3), w, h - 2 * (h // 3), C8)

    elif key == "israel":
        pyxel.rect(rx(0), y, w, h, C7)
        pyxel.rect(rx(0), y + 1, w, 1, C12)
        pyxel.rect(rx(0), y + h - 2, w, 1, C12)
        pyxel.pset(rx(w // 2), y + h // 2, C12)

    elif key == "antarctica":
        pyxel.rect(rx(0), y, w, h, C12)
        pyxel.tri(rx(w // 2), y + 1, rx(4), y + h - 2, rx(w - 4), y + h - 2, C7)

    elif key == "india":
        pyxel.rect(rx(0), y, w, h // 3, C10)
        pyxel.rect(rx(0), y + h // 3, w, h // 3, C7)
        pyxel.rect(rx(0), y + 2 * (h // 3), w, h - 2 * (h // 3), C3)
        pyxel.pset(rx(w // 2), y + h // 2, C12)

    elif key == "italy":
        pyxel.rect(rx(0), y, w // 3, h, C3)
        pyxel.rect(rx(w // 3), y, w // 3, h, C7)
        pyxel.rect(rx(2 * w // 3), y, w - 2 * (w // 3), h, C8)

    elif key == "russia":
        pyxel.rect(rx(0), y, w, h // 3, C7)
        pyxel.rect(rx(0), y + h // 3, w, h // 3, C12)
        pyxel.rect(rx(0), y + 2 * (h // 3), w, h - 2 * (h // 3), C8)

    elif key == "poland":
        pyxel.rect(rx(0), y, w, h // 2, C7)
        pyxel.rect(rx(0), y + h // 2, w, h - h // 2, C8)

    elif key == "france":
        pyxel.rect(rx(0), y, w // 3, h, C1)
        pyxel.rect(rx(w // 3), y, w // 3, h, C7)
        pyxel.rect(rx(2 * w // 3), y, w - 2 * (w // 3), h, C8)

    elif key == "nigeria":
        pyxel.rect(rx(0), y, w // 3, h, C3)
        pyxel.rect(rx(w // 3), y, w // 3, h, C7)
        pyxel.rect(rx(2 * w // 3), y, w - 2 * (w // 3), h, C3)

    elif key == "switzerland":
        pyxel.rect(rx(0), y, w, h, C8)
        pyxel.rect(rx(w // 2 - 1), y + 2, 2, h - 4, C7)
        pyxel.rect(rx(3), y + h // 2 - 1, w - 6, 2, C7)

    elif key == "canada":
        pyxel.rect(rx(0), y, w // 4, h, C8)
        pyxel.rect(rx(w // 4), y, w // 2, h, C7)
        pyxel.rect(rx(3 * w // 4), y, w - 3 * w // 4, h, C8)
        pyxel.pset(rx(w // 2), y + h // 2, C8)

    elif key == "ukraine":
        pyxel.rect(rx(0), y, w, h // 2, C12)
        pyxel.rect(rx(0), y + h // 2, w, h - h // 2, C10)

    elif key == "turkiye":
        pyxel.rect(rx(0), y, w, h, C8)
        # Crescent: a light disc partly covered by a background-colored one.
        pyxel.circ(rx(6), y + h // 2, 2, C7)
        pyxel.circ(rx(7), y + h // 2, 2, C8)
        pyxel.pset(rx(11), y + h // 2, C7)

    elif key == "south_korea":
        pyxel.rect(rx(0), y, w, h, C7)
        pyxel.pset(rx(w // 2 - 1), y + h // 2, C8)
        pyxel.pset(rx(w // 2), y + h // 2, C12)

    elif key == "north_korea":
        pyxel.rect(rx(0), y, w, h, C8)
        pyxel.rect(rx(0), y + 1, w, 1, C12)
        pyxel.rect(rx(0), y + h - 2, w, 1, C12)
        pyxel.pset(rx(w // 2), y + h // 2, C7)

    elif key == "saudi":
        pyxel.rect(rx(0), y, w, h, C3)
        pyxel.line(rx(3), y + h // 2, rx(w - 4), y + h // 2, C7)

    elif key == "belgium":
        pyxel.rect(rx(0), y, w // 3, h, C0)
        pyxel.rect(rx(w // 3), y, w // 3, h, C10)
        pyxel.rect(rx(2 * w // 3), y, w - 2 * (w // 3), h, C8)

    elif key == "austria":
        pyxel.rect(rx(0), y, w, h // 3, C8)
        pyxel.rect(rx(0), y + h // 3, w, h // 3, C7)
        pyxel.rect(rx(0), y + 2 * (h // 3), w, h - 2 * (h // 3), C8)

    elif key == "romania":
        pyxel.rect(rx(0), y, w // 3, h, C1)
        pyxel.rect(rx(w // 3), y, w // 3, h, C10)
        pyxel.rect(rx(2 * w // 3), y, w - 2 * (w // 3), h, C8)

    elif key == "japan":
        pyxel.rect(rx(0), y, w, h, C7)
        pyxel.circ(rx(w // 2), y + h // 2, 3, C8)

    elif key == "uk":
        pyxel.rect(rx(0), y, w, h, C1)
        pyxel.line(rx(0), y, rx(w - 1), y + h - 1, C7)
        pyxel.line(rx(w - 1), y, rx(0), y + h - 1, C7)
        pyxel.rect(rx(w // 2 - 1), y, 2, h, C8)
        pyxel.rect(rx(0), y + h // 2 - 1, w, 2, C8)

    elif key == "spain":
        pyxel.rect(rx(0), y, w, h // 4, C8)
        pyxel.rect(rx(0), y + h // 4, w, h // 2, C10)
        pyxel.rect(rx(0), y + 3 * h // 4, w, h - 3 * h // 4, C8)

    elif key == "brazil":
        pyxel.rect(rx(0), y, w, h, C3)
        pyxel.tri(rx(w // 2), y + 1, rx(2), y + h // 2, rx(w // 2), y + h - 2, C10)
        pyxel.tri(rx(w // 2), y + 1, rx(w - 2), y + h // 2, rx(w // 2), y + h - 2, C10)
        pyxel.pset(rx(w // 2), y + h // 2, C12)

    elif key == "australia":
        pyxel.rect(rx(0), y, w, h, C1)
        pyxel.rect(rx(0), y, w // 2, h // 2, C7)
        pyxel.pset(rx(w - 4), y + h // 2, C7)

    elif key == "south_africa":
        pyxel.rect(rx(0), y, w, h, C3)
        pyxel.line(rx(0), y + h // 2, rx(w - 1), y + h // 2, C10)
        pyxel.line(rx(0), y + h // 2 - 1, rx(w - 1), y + h // 2 - 1, C0)

    else:
        pyxel.rect(rx(0), y, w, h, C7)

    # Frame last so the design above cannot overpaint it.
    pyxel.rectb(rx(0), y, w, h, C7)


# ------------------------------------------------------------
# TERRAIN
# ------------------------------------------------------------
class Terrain:
    """Procedural ground profile with scripted gaps and player-made craters.

    Ground height depends on world x and a *difficulty* value (bucketed by
    ``band``). Craters added with ``add_crater`` sink the ground and, near
    their centre, open a gap. At most 28 craters are kept; the oldest is
    dropped first.
    """

    def __init__(self):
        # Each crater: {"x", "r" (radius), "gap_r" (gap radius), "depth"}.
        self.craters = []

    def reset(self):
        """Remove every crater."""
        self.craters.clear()

    def band(self, difficulty):
        """Map *difficulty* to a terrain band 0..3 (thresholds 2.5 / 5 / 8)."""
        if difficulty < 2.5:
            return 0
        if difficulty < 5:
            return 1
        if difficulty < 8:
            return 2
        return 3

    def base_ground_y_at(self, x, difficulty):
        """Return the ground y at world *x* before crater deformation.

        Band 0 is flat at GROUND_Y; higher bands layer larger sine waves, and
        band 3 also adds step offsets on a 200px x 6 segment cycle.
        """
        b = self.band(difficulty)
        if b == 0:
            return GROUND_Y
        if b == 1:
            return GROUND_Y - int(3 * math.sin(x * 0.018))
        if b == 2:
            return GROUND_Y - int(5 * math.sin(x * 0.020)) - int(2 * math.sin(x * 0.05))
        y = GROUND_Y - int(7 * math.sin(x * 0.021)) - int(3 * math.sin(x * 0.058))
        seg = int(x // 200) % 6
        if seg == 2:
            y += 4
        elif seg == 4:
            y -= 3
        return y

    def scripted_gap_at(self, x, difficulty):
        """Return True where a fixed gap exists (difficulty >= 7 only).

        Gaps span local offsets 84..116 of segments 3 and 6 in a repeating
        8 x 240px pattern.
        """
        if difficulty < 7:
            return False
        seg = int(x // 240) % 8
        local = int(x % 240)
        return seg in (3, 6) and 84 <= local <= 116

    def crater_shape(self, x):
        """Return ``(depth, gap)`` at world *x* from all craters.

        *depth* is the deepest parabolic dip (integer pixels) among craters
        covering *x*; *gap* is True if *x* lies within any crater's gap radius.
        """
        max_depth = 0
        gap = False
        for c in self.craters:
            dx = abs(x - c["x"])
            r = c["r"]
            if dx < r:
                t = dx / r
                depth = int((1 - t * t) * c["depth"])
                max_depth = max(max_depth, depth)
            if dx < c["gap_r"]:
                gap = True
        return max_depth, gap

    def ground_y_at(self, x, difficulty):
        """Return the ground y at world *x*, including crater depth."""
        return self.base_ground_y_at(x, difficulty) + self.crater_shape(x)[0]

    def is_gap_at(self, x, difficulty):
        """Return True if world *x* is a scripted gap or inside a crater's gap."""
        return self.scripted_gap_at(x, difficulty) or self.crater_shape(x)[1]

    def add_crater(self, x, radius):
        """Add a crater centred at *x*; keeps only the newest 28."""
        self.craters.append({
            "x": x,
            "r": radius,
            "gap_r": radius * 0.48,
            "depth": int(radius * 0.75)
        })
        if len(self.craters) > 28:
            self.craters.pop(0)

    def draw(self, cam_x, difficulty):
        """Render the ground for screen columns 0..SCREEN_W-1 at camera *cam_x*."""
        # Sample each column once. crater_shape() walks every crater, and the
        # previous version re-evaluated it (and the base height) up to five
        # times per column across the two passes.
        columns = []
        for sx in range(SCREEN_W):
            wx = cam_x + sx
            base = self.base_ground_y_at(wx, difficulty)
            depth, crater_gap = self.crater_shape(wx)
            gy = base + depth
            is_gap = self.scripted_gap_at(wx, difficulty) or crater_gap
            columns.append((gy, is_gap))

            if is_gap:
                # Dark pit between the undisturbed surface and the crater
                # floor. Skipped when there is no dip: line(base, base-1)
                # would otherwise blacken the pixel just above the ground.
                if gy > base:
                    pyxel.line(sx, base, sx, gy - 1, C0)
                pyxel.line(sx, gy, sx, SCREEN_H, C4)
            else:
                pyxel.line(sx, gy, sx, SCREEN_H, C3)

        # Surface texture on every other column, scrolling at 0.7x camera speed.
        scroll = int(cam_x * 0.7)
        for sx in range(0, SCREEN_W, 2):
            gy, is_gap = columns[sx]
            if is_gap:
                continue
            pat = (sx + scroll) % 18
            if pat < 6:
                pyxel.pset(sx, gy + 1, C13)
                pyxel.pset(sx, gy + 3, C4)
            elif pat < 12:
                pyxel.line(sx, gy + 5, sx + 1, gy + 5, C5)
            else:
                pyxel.pset(sx, gy + 2, C4)

        # Crater rims for craters near the visible range.
        for c in self.craters:
            sx = c["x"] - cam_x
            r = c["r"]
            if -r - 20 <= sx <= SCREEN_W + r + 20:
                rim_y = self.base_ground_y_at(c["x"], difficulty)
                pyxel.line(int(sx - r), rim_y, int(sx - r + 10), rim_y - 3, C4)
                pyxel.line(int(sx + r - 10), rim_y - 3, int(sx + r), rim_y, C4)


# ------------------------------------------------------------
# FX
# ------------------------------------------------------------
class Particle:
    """Short-lived dot/disc with velocity, optional gravity and 1% drag per frame."""

    def __init__(self, x, y, vx, vy, color, life, size=1, gravity=0.0):
        self.x, self.y = x, y
        self.vx, self.vy = vx, vy
        self.color = color
        self.life = life          # frames remaining
        self.size = size          # <= 1 draws a single pixel, else a circle radius
        self.gravity = gravity    # added to vy each frame
        self.alive = True

    def update(self):
        """Advance one frame; the particle dies (without moving) once life hits 0."""
        self.life -= 1
        if self.life > 0:
            self.vy += self.gravity
            self.x += self.vx
            self.y += self.vy
            self.vx *= 0.99
            self.vy *= 0.99
        else:
            self.alive = False

    def draw(self, cam_x):
        """Draw at the camera-relative position, skipping anything well off-screen."""
        px = int(self.x - cam_x)
        py = int(self.y)
        visible = -8 <= px <= SCREEN_W + 8 and -8 <= py <= SCREEN_H + 8
        if not visible:
            return
        if self.size > 1:
            pyxel.circ(px, py, self.size, self.color)
        else:
            pyxel.pset(px, py, self.color)


class Explosion:
    """Expanding blast effect; "nuke" lasts 28 frames, every other kind 16."""

    def __init__(self, x, y, power=12, kind="normal"):
        self.x = x
        self.y = y
        self.power = power    # final radius scale
        self.kind = kind      # "nuke", "muzzle", or anything else for a normal blast
        self.timer = 28 if kind == "nuke" else 16
        self.max_timer = self.timer
        self.alive = True

    def update(self):
        """Count down one frame; dies when the timer reaches zero."""
        self.timer -= 1
        if self.timer <= 0:
            self.alive = False

    def draw(self, cam_x):
        """Draw concentric discs whose radius grows as the timer runs out."""
        cx = int(self.x - cam_x)
        cy = int(self.y)
        if not -100 <= cx <= SCREEN_W + 100:
            return

        progress = 1.0 - self.timer / self.max_timer
        radius = max(1, int(self.power * progress))

        if self.kind == "nuke":
            pyxel.circ(cx, cy, radius + 14, C9)
            pyxel.circ(cx, cy, radius + 8, C10)
            pyxel.circ(cx, cy, radius + 3, C7)
            # Mushroom cap and side puffs above the fireball.
            cap_y = cy - radius
            pyxel.circ(cx, cap_y - 10, max(3, radius // 2 + 5), C10)
            pyxel.circ(cx - 12, cap_y - 5, max(2, radius // 2 + 2), C9)
            pyxel.circ(cx + 12, cap_y - 5, max(2, radius // 2 + 2), C9)
            pyxel.line(cx, cy + radius, cx, cap_y - 8, C7)
            return

        if self.kind == "muzzle":
            pyxel.circ(cx, cy, radius + 2, C9)
            pyxel.circ(cx, cy, radius, C7)
            return

        for extra, col in ((4, C8), (2, C9), (0, C10)):
            pyxel.circ(cx, cy, radius + extra, col)


class FlashText:
    """Floating label that rises 0.35 px per frame and expires after *life* frames."""

    def __init__(self, x, y, text, color=C7, life=40):
        self.x, self.y = x, y
        self.text = text
        self.color = color
        self.life = life
        self.alive = True

    def update(self):
        """Tick the lifetime and drift upward."""
        self.life -= 1
        self.y -= 0.35
        if self.life <= 0:
            self.alive = False

    def draw(self, cam_x):
        """Print the text at its camera-relative position if roughly on screen."""
        tx = int(self.x - cam_x)
        if not -80 <= tx <= SCREEN_W + 80:
            return
        pyxel.text(tx, int(self.y), self.text, self.color)


# ------------------------------------------------------------
# PROJECTILES
# ------------------------------------------------------------
class Bullet:
    """Straight-line projectile that remembers its previous position for a trail."""

    def __init__(self, x, y, vx, vy, dmg=1, color=C7, radius=2, friendly=True, life=90):
        self.x = self.px = x
        self.y = self.py = y
        self.vx, self.vy = vx, vy
        self.dmg = dmg
        self.color = color
        self.radius = radius
        self.friendly = friendly  # True for player shots, False for enemy shots
        self.life = life
        self.alive = True

    def update(self):
        """Move one step; die when out of lifetime or outside the world bounds."""
        self.px, self.py = self.x, self.y
        self.x += self.vx
        self.y += self.vy
        self.life -= 1
        if self.life <= 0:
            self.alive = False
        elif not (-200 <= self.x <= 999999 and -80 <= self.y <= SCREEN_H + 80):
            self.alive = False

    def draw(self, cam_x):
        """Draw a trail from the last position plus a dot with a C7 core."""
        head_x = int(self.x - cam_x)
        head_y = int(self.y)
        pyxel.line(int(self.px - cam_x), int(self.py), head_x, head_y, self.color)
        pyxel.circ(head_x, head_y, self.radius, self.color)
        pyxel.pset(head_x, head_y, C7)


class Rocket:
    """Player rocket that turns toward the nearest living enemy each frame."""

    def __init__(self, x, y):
        self.x, self.y = x, y
        self.vx, self.vy = 3.2, -0.3
        self.speed = 3.8   # velocity magnitude once steering kicks in
        self.turn = 0.08   # max heading change per frame (radians)
        self.life = 130
        self.alive = True

    def _nearest_enemy(self, enemies):
        """Return the living enemy whose centre is closest, or None."""
        target = None
        target_d = 10**12
        for foe in enemies:
            if not foe.alive:
                continue
            d = dist2(self.x, self.y, foe.x + foe.w / 2, foe.y + foe.h / 2)
            if d < target_d:
                target, target_d = foe, d
        return target

    def update(self, enemies):
        """Age, steer toward the closest target (rate-limited), then move."""
        self.life -= 1
        if self.life <= 0:
            self.alive = False
            return

        target = self._nearest_enemy(enemies)
        if target:
            wanted = math.atan2(
                (target.y + target.h / 2) - self.y,
                (target.x + target.w / 2) - self.x,
            )
            heading = math.atan2(self.vy, self.vx)

            # Wrap the turn into (-pi, pi] so the rocket takes the short way round.
            delta = wanted - heading
            while delta > math.pi:
                delta -= math.tau
            while delta < -math.pi:
                delta += math.tau

            heading += clamp(delta, -self.turn, self.turn)
            self.vx = math.cos(heading) * self.speed
            self.vy = math.sin(heading) * self.speed

        self.x += self.vx
        self.y += self.vy

        if not (self.x >= -100 and -100 <= self.y <= SCREEN_H + 100):
            self.alive = False

    def draw(self, cam_x):
        """Draw a short two-tone streak aligned with the velocity."""
        rx = int(self.x - cam_x)
        ry = int(self.y)
        heading = math.atan2(self.vy, self.vx)
        ux = int(math.cos(heading) * 5)
        uy = int(math.sin(heading) * 5)
        pyxel.line(rx - ux, ry - uy, rx, ry, C9)
        pyxel.line(rx, ry, rx + ux, ry + uy, C10)
        pyxel.pset(rx + ux, ry + uy, C7)


class BombProjectile:
    """Ballistic bomb: constant 0.16 px/frame^2 downward acceleration."""

    def __init__(self, x, y, vx, vy, power=22):
        self.x, self.y = x, y
        self.vx, self.vy = vx, vy
        self.power = power
        self.alive = True

    def update(self):
        """Apply gravity, move, and die once far below the screen or off the left edge."""
        self.vy += 0.16
        self.x += self.vx
        self.y += self.vy
        if not (self.y <= SCREEN_H + 120 and self.x >= -100):
            self.alive = False

    def draw(self, cam_x):
        """Draw the bomb body with an inner core and a C7 fuse pixel."""
        bx = int(self.x - cam_x)
        by = int(self.y)
        for radius, col in ((4, C8), (2, C9)):
            pyxel.circ(bx, by, radius, col)
        pyxel.pset(bx, by - 4, C7)


# ------------------------------------------------------------
# ENEMIES
# ------------------------------------------------------------
class Enemy:
    """Base class for hostile entities: hitbox, hit points, score value, hit flash."""

    def __init__(self, x, y, w, h, hp, score):
        self.x = x
        self.y = y
        self.w = w
        self.h = h
        self.hp = hp
        self.max_hp = hp
        self.score = score
        self.flash = 0  # frames left of hit-flash (subclasses draw in C7 while > 0)
        self.alive = True

    def hurt(self, dmg):
        """Apply *dmg*; return True only on the hit that kills this enemy.

        Hits on an enemy that is already dead are ignored and return False.
        Previously such hits returned True again, so a caller awarding score
        on True could count one kill twice when several projectiles landed in
        the same frame.
        """
        if not self.alive:
            return False
        self.hp -= dmg
        self.flash = 4
        if self.hp <= 0:
            self.alive = False
            return True
        return False

    def update_flash(self):
        """Count the hit-flash timer down by one frame, stopping at zero."""
        if self.flash > 0:
            self.flash -= 1


class MissileEnemy(Enemy):
    """Falling missile (8x18 hitbox, 1 hp, 12 points) spawned at y = -20."""

    def __init__(self, x, speed):
        super().__init__(x, -20, 8, 18, 1, 12)
        self.speed = speed                    # pixels per frame, downward
        self.anim = random.randint(0, 100)    # random phase for the exhaust flicker

    def update(self, game):
        """Fall straight down by ``speed``; *game* is unused."""
        self.anim += 1
        self.y += self.speed
        self.update_flash()

    def draw(self, cam_x):
        """Draw body, nose cap, tail fins and a flickering exhaust pixel."""
        left = int(self.x - cam_x)
        top = int(self.y)
        tail = top + self.h
        tip_x = left + 4
        pyxel.rect(left, top, self.w, self.h, C9 if self.flash <= 0 else C7)
        pyxel.rect(left + 2, top - 4, 4, 4, C7)
        pyxel.line(left + 1, tail, tip_x, tail + 4, C8)
        pyxel.line(left + 7, tail, tip_x, tail + 4, C8)
        if self.anim % 4 < 2:
            pyxel.pset(tip_x, tail + 5, C10)


class DroneEnemy(Enemy):
    """Flying drone (18x10, 2 hp, 18 points) heading from (x, y) toward (tx, ty)."""

    def __init__(self, x, y, tx, ty, speed):
        super().__init__(x, y, 18, 10, 2, 18)
        dx = tx - x
        dy = ty - y
        # Guard against a zero-length direction when the target equals the start.
        length = math.sqrt(dx * dx + dy * dy) or 1
        self.vx = dx / length * speed
        self.vy = dy / length * speed
        self.wave = random.random() * math.tau  # phase of the vertical bob

    def update(self, game):
        """Move along the heading with a sine bob; die when left behind or off screen."""
        self.wave += 0.15
        bob = math.sin(self.wave) * 0.5
        self.x += self.vx
        self.y += self.vy + bob
        self.update_flash()
        behind_camera = self.x < game.cam_x - 200
        if behind_camera or self.y < -80 or self.y > SCREEN_H + 80:
            self.alive = False

    def draw(self, cam_x):
        """Draw outline, rotor arms, hull, canopy and indicator lights."""
        left = int(self.x - cam_x)
        top = int(self.y)
        right = left + self.w
        bottom = top + self.h
        hull = C15 if self.flash <= 0 else C7

        pyxel.rect(left - 1, top - 1, self.w + 2, self.h + 2, C0)
        pyxel.line(left - 4, top + 5, left, top + 4, C6)
        pyxel.line(right, top + 4, right + 4, top + 5, C6)
        pyxel.rect(left, top, self.w, self.h, hull)
        pyxel.rect(left + 2, top + 1, self.w - 4, self.h - 3, C1)
        pyxel.rect(left + 4, top + 2, 10, 4, C12)
        for px, py, col in (
            (left + 6, top + 3, C7),
            (left + 11, top + 3, C7),
            (left + 5, top - 1, C7),
            (left + 12, top - 1, C7),
            (left + 4, bottom, C9),
            (left + 13, bottom, C9),
        ):
            pyxel.pset(px, py, col)


class TurretEnemy(Enemy):
    """Ground turret (18x16, 4 hp, 28 points) that fires aimed shots at the player.

    The barrel pivots at (x + BARREL_DX, y + BARREL_DY). Both aiming and the
    shot origin now use that pivot. Previously the barrel was drawn from
    y - 13 while the aim was computed from y - 6 and shots spawned at y - 5,
    i.e. from inside the turret body rather than out of the visible barrel.
    """

    BARREL_DX = 9
    BARREL_DY = -13
    BARREL_LEN = 10
    SHOT_SPEED = 2.6

    def __init__(self, x, y):
        super().__init__(x, y, 18, 16, 4, 28)
        self.cooldown = random.randint(25, 65)  # frames until the first shot
        self.barrel_a = 0                        # barrel angle in radians

    def update(self, game):
        """Track the player with the barrel and fire from its tip when ready."""
        pivot_x = self.x + self.BARREL_DX
        pivot_y = self.y + self.BARREL_DY
        # Aim point: player.x + 8 / player.y + 10 (offsets kept from the original).
        target_x = game.player.x + 8
        target_y = game.player.y + 10
        self.barrel_a = math.atan2(target_y - pivot_y, target_x - pivot_x)
        self.cooldown -= 1
        self.update_flash()

        if self.cooldown <= 0:
            self.cooldown = random.randint(60, 95)
            dir_x = math.cos(self.barrel_a)
            dir_y = math.sin(self.barrel_a)
            game.enemy_bullets.append(Bullet(
                pivot_x + dir_x * self.BARREL_LEN,
                pivot_y + dir_y * self.BARREL_LEN,
                dir_x * self.SHOT_SPEED,
                dir_y * self.SHOT_SPEED,
                dmg=1, color=C8, radius=2, friendly=False, life=120
            ))

    def draw(self, cam_x):
        """Draw base plate, body, dome, the aimed barrel and two eye pixels."""
        sx = int(self.x - cam_x)
        body = C7 if self.flash > 0 else C4
        pyxel.rect(sx, int(self.y) - 12, 18, 12, body)
        pyxel.rect(sx + 4, int(self.y) - 16, 10, 6, C5)
        pivot_sx = sx + self.BARREL_DX
        pivot_sy = int(self.y) + self.BARREL_DY
        bx = pivot_sx + int(math.cos(self.barrel_a) * self.BARREL_LEN)
        by = pivot_sy + int(math.sin(self.barrel_a) * self.BARREL_LEN)
        pyxel.line(pivot_sx, pivot_sy, bx, by, C7)
        pyxel.rect(sx - 2, int(self.y), 22, 3, C5)
        pyxel.pset(sx + 7, int(self.y) - 8, C7)
        pyxel.pset(sx + 11, int(self.y) - 8, C7)


class HeavyBombEnemy(Enemy):
    """Large falling bomb (18x32, 4 hp, 40 points) dropping at a fixed speed."""

    def __init__(self, x, y, speed):
        super().__init__(x, y, 18, 32, 4, 40)
        self.speed = speed  # pixels per frame, downward

    def update(self, game):
        """Fall straight down; *game* is unused."""
        self.y += self.speed
        self.update_flash()

    def draw(self, cam_x):
        """Draw casing, nose cap, tail fins and a C10 band."""
        left = int(self.x - cam_x)
        top = int(self.y)
        tail = top + self.h
        fin_tip = left + 9
        pyxel.rect(left, top, self.w, self.h, C8 if self.flash <= 0 else C7)
        pyxel.rect(left + 4, top - 5, 10, 5, C7)
        pyxel.line(left + 2, tail, fin_tip, tail + 6, C9)
        pyxel.line(left + 16, tail, fin_tip, tail + 6, C9)
        pyxel.rect(left + 6, top + 10, 6, 8, C10)


class BomberEnemy(Enemy):
    """Large bomber (154x36, 12 hp, 120 points) that drops one HeavyBombEnemy.

    While its centre is within 90 px of ``drop_x`` a 120-frame countdown runs
    (shown as a "NUKE" warning); when it expires a HeavyBombEnemy is spawned.
    """

    def __init__(self, x, y, direction, drop_x):
        super().__init__(x, y, 154, 36, 12, 120)
        self.dir = direction
        self.vx = 1.25 * direction
        self.drop_x = drop_x
        self.drop_cd = 120
        self.has_dropped = False
        self.anim = 0
        self.warning_phase = False

    def update(self, game):
        """Fly horizontally, run the drop countdown, and despawn far off screen."""
        self.anim += 1
        self.x += self.vx
        self.update_flash()

        center_x = self.x + self.w // 2
        if not self.has_dropped and abs(center_x - self.drop_x) < 90:
            self.warning_phase = True
            self.drop_cd -= 1
            if self.drop_cd <= 0:
                self.has_dropped = True
                game.enemies.append(HeavyBombEnemy(center_x - 9, self.y + 24, 2.0))

        lo = game.cam_x - 300
        hi = game.cam_x + SCREEN_W + 300
        if not lo <= self.x <= hi:
            self.alive = False

    def draw(self, cam_x):
        """Draw fuselage, nose on the travel side, wings, engines and the drop warning."""
        ox = int(self.x - cam_x)
        oy = int(self.y)
        hull = C12 if self.flash <= 0 else C7

        pyxel.rect(ox + 28, oy + 12, 100, 12, hull)
        pyxel.rect(ox + 20, oy + 14, 116, 8, C6)

        # Nose cone and cockpit on the side the bomber faces.
        if self.dir == 1:
            pyxel.tri(ox + 128, oy + 12, ox + 154, oy + 18, ox + 128, oy + 24, hull)
            pyxel.rect(ox + 136, oy + 15, 8, 4, C7)
        else:
            pyxel.tri(ox + 28, oy + 12, ox + 2, oy + 18, ox + 28, oy + 24, hull)
            pyxel.rect(ox + 10, oy + 15, 8, 4, C7)

        # Lower and upper wing pairs.
        for ax, ay, bx, by, cx, cy in (
            (52, 14, 14, 38, 70, 20),
            (104, 14, 142, 38, 86, 20),
            (62, 14, 28, -8, 78, 12),
            (94, 14, 128, -8, 78, 12),
        ):
            pyxel.tri(ox + ax, oy + ay, ox + bx, oy + by, ox + cx, oy + cy, C1)

        exhaust_on = self.anim % 4 < 2
        for ex in (54, 72, 90, 108):
            pyxel.rect(ox + ex, oy + 24, 8, 5, C4)
            if exhaust_on:
                pyxel.pset(ox + ex + 3, oy + 30, C9)

        if self.warning_phase and not self.has_dropped:
            countdown = max(1, self.drop_cd // 10 + 1)
            pyxel.rect(ox + 63, oy - 14, 24, 10, C0)
            pyxel.rectb(ox + 63, oy - 14, 24, 10, C8)
            pyxel.text(ox + 71, oy - 11, str(countdown), C10)
            if (pyxel.frame_count // 6) % 2 == 0:
                pyxel.text(ox + 54, oy - 24, "NUKE", C8)


class BossEnemy(Enemy):
    """Hovering stage boss that fires aimed bullet fans and calls in missile strikes.

    It bobs vertically around y=64 and patrols horizontally so that it stays
    between roughly 70 and 190 px ahead of the player. Attack cadence increases
    once its hit points drop below fixed thresholds.
    """

    def __init__(self, x):
        # NOTE(review): positional meaning is defined by Enemy.__init__ (not visible
        # here); presumably (x, y, w, h, hp, score) - confirm.
        super().__init__(x, 64, 76, 74, 80, 650)
        self.fire_cd = 34      # frames until the next bullet fan
        self.missile_cd = 110  # frames until the next missile strike
        self.dir = -1          # horizontal patrol direction (-1 left, +1 right)
        self.wave = 0          # phase of the vertical bobbing motion

    def update(self, game):
        """Move, then fire bullets and/or spawn missiles when their cooldowns expire."""
        self.wave += 0.04
        self.y = 64 + math.sin(self.wave) * 8
        self.x += self.dir * 0.7

        # Turn around at the edges of the band in front of the player.
        if self.x < game.player.x + 70:
            self.dir = 1
        if self.x > game.player.x + 190:
            self.dir = -1

        self.fire_cd -= 1
        self.missile_cd -= 1
        self.update_flash()

        if self.fire_cd <= 0:
            self.fire_cd = 16 if self.hp < 30 else 26
            muzzle_x = self.x + 38
            muzzle_y = self.y + 24
            # The aim angle is the same for every bullet of the fan: compute it once.
            aim = math.atan2((game.player.y + 10) - muzzle_y, (game.player.x + 8) - muzzle_x)
            for off in (-0.20, -0.08, 0.08, 0.20):
                ang = aim + off
                game.enemy_bullets.append(Bullet(
                    muzzle_x, muzzle_y,
                    math.cos(ang) * 3.2,
                    math.sin(ang) * 3.2,
                    dmg=1, color=C8, radius=2, friendly=False, life=150
                ))

        if self.missile_cd <= 0:
            self.missile_cd = 70 if self.hp < 30 else 105
            # NOTE(review): the missile-count threshold (35) differs from the
            # cooldown thresholds (30); kept as-is, confirm it is intentional.
            count = 4 if self.hp < 35 else 2
            for _ in range(count):
                mx = self.x + random.randint(0, self.w)
                game.enemies.append(MissileEnemy(mx, 1.9 + random.random() * 0.8))

    def draw(self, cam_x):
        """Draw the boss sprite plus a fixed-position health bar at the top of the screen."""
        sx = int(self.x - cam_x)
        sy = int(self.y)

        body = C7 if self.flash > 0 else C5
        bright = C12 if self.flash == 0 else C7

        # Legs and feet.
        pyxel.rect(sx + 18, sy + 52, 12, 20, C4)
        pyxel.rect(sx + 46, sy + 52, 12, 20, C4)
        pyxel.rect(sx + 14, sy + 70, 18, 4, C0)
        pyxel.rect(sx + 44, sy + 70, 18, 4, C0)

        # Torso.
        pyxel.rect(sx + 16, sy + 14, 44, 38, body)
        pyxel.rect(sx + 10, sy + 22, 56, 28, C1)

        # Shoulders.
        pyxel.rect(sx, sy + 18, 16, 14, bright)
        pyxel.rect(sx + 60, sy + 18, 16, 14, bright)

        # Shoulder pods with indicator lights.
        pyxel.rect(sx + 2, sy + 10, 12, 8, C8)
        pyxel.rect(sx + 62, sy + 10, 12, 8, C8)
        pyxel.pset(sx + 5, sy + 13, C10)
        pyxel.pset(sx + 9, sy + 13, C10)
        pyxel.pset(sx + 65, sy + 13, C10)
        pyxel.pset(sx + 69, sy + 13, C10)

        # Head and eyes.
        pyxel.rect(sx + 24, sy, 28, 16, body)
        pyxel.rect(sx + 26, sy + 2, 24, 12, C1)
        pyxel.rect(sx + 29, sy + 6, 6, 3, C8)
        pyxel.rect(sx + 41, sy + 6, 6, 3, C8)

        # Pulsing core.
        pyxel.circ(sx + 38, sy + 30, 8, C8)
        pyxel.circ(sx + 38, sy + 30, 5, C10)
        pyxel.circ(sx + 38, sy + 30, 2 + (pyxel.frame_count // 8) % 2, C7)

        # Arms and fists.
        pyxel.line(sx + 10, sy + 32, sx - 10, sy + 42, C12)
        pyxel.line(sx + 66, sy + 32, sx + 86, sy + 42, C12)
        pyxel.rect(sx - 14, sy + 40, 8, 8, C8)
        pyxel.rect(sx + 82, sy + 40, 8, 8, C8)

        # Screen-space health bar. Clamp the fill so overkill damage (hp < 0) or
        # an hp above max_hp can never produce a negative or oversized width.
        bar_inner_w = 178
        pyxel.rect(70, 8, 180, 8, C0)
        fill = int(bar_inner_w * self.hp / self.max_hp)
        fill = max(0, min(bar_inner_w, fill))
        pyxel.rect(71, 9, fill, 6, C8)
        pyxel.rectb(70, 8, 180, 8, C7)
        pyxel.text(138, 1, "BOSS: NATIONAL DESTROYER", C7)


# ------------------------------------------------------------
# PICKUPS / COINS
# ------------------------------------------------------------
class PowerUp:
    """Collectible power-up box that falls under gravity and settles on solid terrain."""

    TYPES = ["rapid", "shield", "spread", "rocket", "bomb", "overdrive", "super_shield"]

    def __init__(self, x, y, kind=None):
        """Spawn at (x, y); a falsy *kind* picks a random entry from TYPES."""
        self.x = x
        self.y = y
        self.kind = kind if kind else random.choice(self.TYPES)
        self.vy = 0
        self.alive = True
        self.anim = 0
        self.w = 12
        self.h = 12

    def update(self, terrain, difficulty):
        """Apply gravity, then bounce/settle on the ground unless a gap lies below."""
        self.anim += 1
        self.vy += 0.16
        self.y += self.vy

        probe_x = self.x + 6
        rest_y = terrain.ground_y_at(probe_x, difficulty) - 6
        if terrain.is_gap_at(probe_x, difficulty) or self.y < rest_y:
            return

        # Landed: snap to the rest height and keep a damped rebound.
        self.y = rest_y
        self.vy = -self.vy * 0.25
        if abs(self.vy) < 0.3:
            self.vy = 0

    def draw(self, cam_x):
        """Draw the box with a pulsing ring, coloured by kind and gently bobbing."""
        bob = math.sin(self.anim * 0.15) * 1.5
        sx = int(self.x - cam_x)
        sy = int(self.y + bob)

        palette = dict(
            rapid=C9,
            shield=C12,
            spread=C15,
            rocket=C10,
            bomb=C8,
            overdrive=C14,
            super_shield=C7,
        )
        tint = palette[self.kind]

        pulse = (pyxel.frame_count // 8) % 2
        pyxel.circb(sx + 5, sy + 5, 8 + pulse, tint)
        pyxel.rect(sx + 1, sy + 1, 10, 10, tint)
        pyxel.rectb(sx, sy, 12, 12, C7)
        pyxel.pset(sx + 6, sy + 6, C7)


class Coin:
    """Loose coin that pops out, falls, is pulled toward the player and is banked on contact."""

    def __init__(self, x, y, value=1):
        """Spawn at (x, y) worth *value*, with a random sideways drift and upward pop."""
        self.x = x
        self.y = y
        self.vx = random.uniform(-1.0, 1.0)
        self.vy = random.uniform(-2.0, -0.5)
        self.value = value
        self.alive = True
        self.life = 600  # frames before the coin despawns uncollected
        self.anim = 0

    def update(self, game):
        """Advance physics, magnet pull, ground bounce and pickup for one frame."""
        self.anim += 1
        self.life -= 1
        if self.life <= 0:
            self.alive = False
            return

        # Ballistic motion with light horizontal drag.
        self.vy += 0.10
        self.x += self.vx
        self.y += self.vy
        self.vx *= 0.98

        player = game.player
        target_x = player.x + 8
        target_y = player.y + 10
        to_x = target_x - self.x
        to_y = target_y - self.y
        if to_x == 0 and to_y == 0:
            span = 1  # avoid dividing by zero when exactly on target
        else:
            span = math.sqrt(to_x * to_x + to_y * to_y)

        # Magnet radius grows with the player's coin_magnet upgrade level.
        if span < 28 + player.coin_magnet * 16:
            self.x += to_x / span * 1.6
            self.y += to_y / span * 1.6

        terrain = game.terrain
        difficulty = game.difficulty_value()
        floor_y = terrain.ground_y_at(self.x, difficulty)
        if not terrain.is_gap_at(self.x, difficulty) and self.y >= floor_y - 3:
            self.y = floor_y - 3
            self.vy *= -0.25
            if abs(self.vy) < 0.25:
                self.vy = 0

        if dist2(self.x, self.y, target_x, target_y) < 12 * 12:
            self.alive = False
            player.run_coins += self.value
            game.coins_total += self.value
            game.add_text(self.x, self.y - 6, f"+{self.value}$", C10)
            sfx(0, 3)

    def draw(self, cam_x):
        """Draw a small gold disc with a highlight pixel, bobbing slightly."""
        wobble = math.sin(self.anim * 0.18) * 1.5
        cx = int(self.x - cam_x)
        cy = int(self.y + wobble)
        pyxel.circ(cx, cy, 3, C10)
        pyxel.pset(cx, cy, C7)


# ------------------------------------------------------------
# PLAYER
# ------------------------------------------------------------
class Player:
    """The controllable leader: movement, dashing, shooting, specials and rendering.

    Keys read in this class (through pyxel): A/D move, W jump (a second jump is
    allowed in the air), S accelerates falling while airborne, SPACE dashes,
    SHIFT cycles the base weapon, and the arrow keys choose the firing direction.
    All *_timer / *_cd fields are frame counters decremented in update().
    """

    # Cosmetic skins; each entry maps draw() roles to palette indices.
    SKINS = [
        {"name": "COMMANDER", "body": C12, "cape": C8, "visor": C14},
        {"name": "WARLORD", "body": C5, "cape": C2, "visor": C12},
        {"name": "IMPERATOR", "body": C14, "cape": C8, "visor": C7},
        {"name": "NEON GENERAL", "body": C15, "cape": C12, "visor": C10},
        {"name": "IRON MARSHAL", "body": C6, "cape": C1, "visor": C8},
        {"name": "PHANTOM KING", "body": C0, "cape": C2, "visor": C15},
        {"name": "SUN EMPEROR", "body": C10, "cape": C8, "visor": C14},
        {"name": "FROST TYRANT", "body": C7, "cape": C12, "visor": C1},
    ]

    # Base weapons in SHIFT-cycle order (weapon_index indexes this list).
    WEAPONS = ["BLASTER", "SPREAD", "LASER"]

    def __init__(self, terrain, skin_index=0, country_index=0, extra_hp_bonus=0):
        """Create a player and place it at x=40 on *terrain* (difficulty 1).

        Args:
            terrain: object with ground_y_at(x, difficulty) and
                is_gap_at(x, difficulty), used for placement and collision.
            skin_index: index into SKINS.
            country_index: index into COUNTRIES (selects the drawn flag).
            extra_hp_bonus: added on top of the base max HP of 5.
        """
        self.terrain = terrain
        self.skin_index = skin_index
        self.country_index = country_index
        self.fire_rate_bonus = 0  # frames subtracted from the shot cooldown in fire()
        self.max_hp_bonus = extra_hp_bonus
        self.coin_magnet = 0
        self.reset(40, 1)

    @property
    def skin(self):
        """The SKINS entry currently selected by skin_index."""
        return self.SKINS[self.skin_index]

    @property
    def country(self):
        """The COUNTRIES entry currently selected by country_index."""
        return COUNTRIES[self.country_index]

    @property
    def weapon(self):
        """Active weapon name; an active spread power-up overrides the selected weapon."""
        if self.spread_timer > 0:
            return "SPREAD"
        return self.WEAPONS[self.weapon_index]

    def hurtbox(self):
        """Return (x, y, w, h) of the damage rectangle, inset slightly from the sprite box."""
        return self.x + 2, self.y + 2, PLAYER_W - 4, PLAYER_H - 3

    def protected(self):
        """True while the invulnerability timer is running."""
        return self.invuln > 0

    def add_score(self, pts):
        """Add *pts* plus a combo bonus (2 per current combo step), then extend the combo."""
        bonus = self.combo * 2
        self.score += pts + bonus
        self.combo += 1
        self.combo_timer = 120  # frames before the combo lapses (see update())

    def take_damage(self, dmg):
        """Apply *dmg* unless protected; returns True if damage was taken.

        Taking damage grants 70 frames of invulnerability, starts the hurt blink
        and breaks the combo.
        """
        if self.protected():
            return False
        self.hp -= dmg
        self.invuln = 70
        self.hurt_flash = 10
        self.combo = 0
        self.combo_timer = 0
        return True

    def reset(self, x, difficulty):
        """Reset position, stats, ammo, timers and run score; stand on the ground at *x*."""
        self.x = x
        # Feet on the terrain surface sampled at the sprite's horizontal centre.
        self.y = self.terrain.ground_y_at(x + 8, difficulty) - PLAYER_H
        self.vx = 0
        self.vy = 0
        self.facing = 1
        self.on_ground = True

        self.max_hp = 5 + self.max_hp_bonus
        self.hp = self.max_hp
        self.invuln = 100  # spawn protection, in frames

        self.speed = 1.9
        self.air_speed = 1.6
        self.jump_power = -5.5
        self.jump_count = 0
        self.max_jumps = 2

        self.dash_timer = 0
        self.dash_cd = 0
        self.dash_dir = 1

        self.weapon_index = 0
        self.shot_cd = 0
        self.shoot_anim = 0
        self.hurt_flash = 0
        self.walk_frame = 0

        # Power-up durations (frames).
        self.rapid_timer = 0
        self.spread_timer = 0
        self.overdrive_timer = 0
        self.super_shield_timer = 0

        # Special ammo and their shared cooldown.
        self.rockets = 3
        self.bombs = 2
        self.special_cd = 0

        self.score = 0
        self.combo = 0
        self.combo_timer = 0
        self.run_coins = 0

    def update(self, difficulty, dev_mode=False):
        """Advance one frame: tick timers, read input, integrate motion, resolve ground.

        Args:
            difficulty: forwarded to the terrain queries.
            dev_mode: when True, refill HP/ammo and keep a minimal invulnerability.
        """
        # Tick every frame counter toward zero.
        if self.invuln > 0:
            self.invuln -= 1
        if self.shot_cd > 0:
            self.shot_cd -= 1
        if self.shoot_anim > 0:
            self.shoot_anim -= 1
        if self.hurt_flash > 0:
            self.hurt_flash -= 1
        if self.dash_cd > 0:
            self.dash_cd -= 1
        if self.dash_timer > 0:
            self.dash_timer -= 1
        if self.rapid_timer > 0:
            self.rapid_timer -= 1
        if self.spread_timer > 0:
            self.spread_timer -= 1
        if self.overdrive_timer > 0:
            self.overdrive_timer -= 1
        if self.super_shield_timer > 0:
            self.super_shield_timer -= 1
        if self.special_cd > 0:
            self.special_cd -= 1
        if self.combo_timer > 0:
            self.combo_timer -= 1
            # Combo lapses when its window runs out.
            if self.combo_timer <= 0:
                self.combo = 0

        if pyxel.btnp(pyxel.KEY_SHIFT):
            self.weapon_index = (self.weapon_index + 1) % len(self.WEAPONS)
            sfx(1, 4)

        move = self.speed if self.on_ground else self.air_speed

        # Dash: direction from held A/D, else the current facing; brief i-frames.
        if pyxel.btnp(pyxel.KEY_SPACE) and self.dash_cd <= 0:
            if pyxel.btn(pyxel.KEY_A):
                self.dash_dir = -1
            elif pyxel.btn(pyxel.KEY_D):
                self.dash_dir = 1
            else:
                self.dash_dir = self.facing
            self.dash_timer = 8
            self.dash_cd = 30
            self.invuln = max(self.invuln, 8)
            sfx(1, 4)

        # A dash overrides normal walking; if both A and D are held, D wins.
        if self.dash_timer > 0:
            self.vx = self.dash_dir * 4.8
        else:
            self.vx = 0
            if pyxel.btn(pyxel.KEY_A):
                self.vx = -move
                self.facing = -1
            if pyxel.btn(pyxel.KEY_D):
                self.vx = move
                self.facing = 1

        # Ground jump, or an air jump while jump_count < max_jumps.
        if pyxel.btnp(pyxel.KEY_W):
            if self.on_ground:
                self.vy = self.jump_power
                self.on_ground = False
                self.jump_count = 1
            elif self.jump_count < self.max_jumps:
                self.vy = self.jump_power + 0.2
                self.jump_count += 1

        # Holding S in the air speeds up the fall.
        if (not self.on_ground) and pyxel.btn(pyxel.KEY_S):
            self.vy += 0.45

        self.vy = min(self.vy + GRAVITY, MAX_FALL)
        self.x += self.vx
        self.y += self.vy
        if self.x < 0:
            self.x = 0

        # Sample the terrain under three "feet": left, middle, right.
        foot_l = self.x + 2
        foot_m = self.x + PLAYER_W / 2
        foot_r = self.x + PLAYER_W - 2

        gaps = [
            self.terrain.is_gap_at(foot_l, difficulty),
            self.terrain.is_gap_at(foot_m, difficulty),
            self.terrain.is_gap_at(foot_r, difficulty),
        ]
        grounds = [
            self.terrain.ground_y_at(foot_l, difficulty),
            self.terrain.ground_y_at(foot_m, difficulty),
            self.terrain.ground_y_at(foot_r, difficulty),
        ]

        # Supported if the middle foot is over solid ground, or both outer feet are.
        supported = (not gaps[1]) or ((not gaps[0]) and (not gaps[2]))
        # Smallest sampled ground y, i.e. the highest surface under any foot.
        target_ground = min(grounds)

        if supported and self.y >= target_ground - PLAYER_H:
            self.y = target_ground - PLAYER_H
            self.vy = 0
            self.on_ground = True
            self.jump_count = 0
        else:
            self.on_ground = False

        if self.on_ground and abs(self.vx) > 0.1:
            self.walk_frame += 1
        else:
            self.walk_frame = 0

        if dev_mode:
            self.hp = self.max_hp
            self.rockets = 99
            self.bombs = 99
            self.invuln = max(self.invuln, 2)

    def fire(self):
        """Fire the active weapon if off cooldown.

        Returns:
            (bullets, effects): two lists; both empty while shot_cd is running.
            The aim is taken from the arrow keys (left/right also turn the
            player), defaulting to the facing direction.
        """
        if self.shot_cd > 0:
            return [], []

        # Cooldown in frames: rapid fire shortens it, spread lengthens it,
        # overdrive shortens it; upgrades subtract, with a floor of 4.
        base_rate = 8 if self.rapid_timer > 0 else 11
        if self.weapon == "SPREAD":
            base_rate += 4
        if self.overdrive_timer > 0:
            base_rate -= 3

        rate = max(4, base_rate - self.fire_rate_bonus)
        self.shot_cd = rate
        self.shoot_anim = 5

        px = self.x + 8
        py = self.y + 10

        if pyxel.btn(pyxel.KEY_LEFT):
            dx, dy = -1, 0
            self.facing = -1
        elif pyxel.btn(pyxel.KEY_RIGHT):
            dx, dy = 1, 0
            self.facing = 1
        elif pyxel.btn(pyxel.KEY_UP):
            dx, dy = 0, -1
        elif pyxel.btn(pyxel.KEY_DOWN):
            dx, dy = 0, 1
        else:
            dx, dy = self.facing, 0

        ang = math.atan2(dy, dx)

        bullets = []
        effects = [Explosion(px + dx * 6, py + dy * 2, 4, "muzzle")]

        if self.weapon == "BLASTER":
            # Tight three-shot burst; overdrive doubles damage.
            for off in (-0.04, 0.0, 0.04):
                a = ang + off
                bullets.append(Bullet(
                    px, py,
                    math.cos(a) * 7.0,
                    math.sin(a) * 7.0,
                    dmg=2 if self.overdrive_timer > 0 else 1,
                    color=C10,
                    radius=2,
                    friendly=True,
                    life=95
                ))

        elif self.weapon == "SPREAD":
            # Five-way fan of slower, shorter-lived shots.
            for off in (-0.18, -0.09, 0, 0.09, 0.18):
                a = ang + off
                bullets.append(Bullet(
                    px, py,
                    math.cos(a) * 5.8,
                    math.sin(a) * 5.8,
                    dmg=1,
                    color=C15,
                    radius=2,
                    friendly=True,
                    life=76
                ))
        else:
            # LASER: a fast heavy bolt plus a thinner trailing bolt.
            bullets.append(Bullet(
                px, py,
                math.cos(ang) * 9.2,
                math.sin(ang) * 9.2,
                dmg=3 if self.overdrive_timer > 0 else 2,
                color=C12,
                radius=2,
                friendly=True,
                life=75
            ))
            bullets.append(Bullet(
                px, py,
                math.cos(ang) * 8.6,
                math.sin(ang) * 8.6,
                dmg=1,
                color=C7,
                radius=1,
                friendly=True,
                life=55
            ))

        sfx(0, 0)
        return bullets, effects

    def fire_rocket(self):
        """Spend one rocket and return a Rocket, or None if out of ammo / on cooldown."""
        if self.rockets <= 0 or self.special_cd > 0:
            return None
        self.rockets -= 1
        self.special_cd = 16
        sfx(0, 5)
        return Rocket(self.x + 10, self.y + 8)

    def fire_bomb(self):
        """Spend one bomb and return a BombProjectile lobbed forward, or None if unavailable."""
        if self.bombs <= 0 or self.special_cd > 0:
            return None
        self.bombs -= 1
        self.special_cd = 18
        return BombProjectile(self.x + 8, self.y + 6, self.facing * 2.2, -3.2, 22)

    def draw(self, cam_x):
        """Draw the player sprite, flag, shield ring, dash trail and muzzle flash."""
        sx = int(self.x - cam_x)
        sy = int(self.y)

        # Blink (skip drawing every other 2-frame period) while hurt.
        if self.hurt_flash > 0 and (pyxel.frame_count // 2) % 2 == 0:
            return

        body = self.skin["body"]
        cape = self.skin["cape"]
        visor = self.skin["visor"]

        # Neighbouring palette indices used as shade variants of the body colour.
        dark_body = max(1, body - 1)
        bright_body = min(15, body + 1)

        step = 0
        if self.on_ground and abs(self.vx) > 0.1:
            step = 1 if (self.walk_frame // 5) % 2 == 0 else -1

        # Flickering shield ring while invulnerable (white with super shield).
        if self.invuln > 0 and (pyxel.frame_count // 4) % 2 == 0:
            col = C7 if self.super_shield_timer > 0 else C12
            pyxel.circb(sx + 8, sy + 11, 13, col)
            pyxel.circb(sx + 8, sy + 11, 11, C6)

        # flag trail during dash
        if self.dash_timer > 0:
            for i in range(1, 5):
                ox = sx - self.facing * i * 5
                pole_x = ox + 19 if self.facing == 1 else ox - 7
                pyxel.line(pole_x, sy + 4, pole_x, sy + 18, C5)
                draw_flag(self.country["key"], pole_x + (1 if self.facing == 1 else -13), sy + 5, 10, 6, pyxel.frame_count * 0.3 + i)

        cape_wave = 1 if (pyxel.frame_count // 4) % 2 == 0 else -1
        pyxel.tri(sx + 4, sy + 8, sx - 4, sy + 21 + cape_wave, sx + 5, sy + 20, cape)
        pyxel.tri(sx + 12, sy + 8, sx + 20, sy + 21 - cape_wave, sx + 11, sy + 20, cape)
        pyxel.line(sx + 6, sy + 8, sx + 3, sy + 20, dark_body)
        pyxel.line(sx + 10, sy + 8, sx + 13, sy + 20, dark_body)

        # visible flag behind player
        pole_x = sx + 19 if self.facing == 1 else sx - 7
        pyxel.line(pole_x, sy + 2, pole_x, sy + 22, C4)
        if self.facing == 1:
            draw_flag(self.country["key"], pole_x + 1, sy + 3, 14, 8, pyxel.frame_count * 0.25)
        else:
            draw_flag(self.country["key"], pole_x - 15, sy + 3, 14, 8, pyxel.frame_count * 0.25)

        # Dark outline behind head and torso.
        pyxel.rect(sx + 3, sy - 1, 10, 7, C0)
        pyxel.rect(sx + 1, sy + 5, 14, 13, C0)

        # Head, headband and crest.
        pyxel.rect(sx + 4, sy, 8, 5, C11)
        pyxel.rect(sx + 3, sy + 1, 10, 3, C11)
        pyxel.rect(sx + 3, sy - 2, 10, 2, visor)
        pyxel.pset(sx + 7, sy - 3, C14)
        pyxel.pset(sx + 8, sy - 3, C14)

        # Eye visor on the facing side.
        if self.facing == 1:
            pyxel.rect(sx + 7, sy + 2, 4, 2, visor)
            pyxel.pset(sx + 10, sy + 2, C7)
        else:
            pyxel.rect(sx + 5, sy + 2, 4, 2, visor)
            pyxel.pset(sx + 5, sy + 2, C7)

        # Shoulder pads.
        pyxel.rect(sx + 1, sy + 6, 4, 4, bright_body)
        pyxel.rect(sx + 11, sy + 6, 4, 4, bright_body)
        pyxel.line(sx + 1, sy + 6, sx + 4, sy + 5, C7)
        pyxel.line(sx + 11, sy + 6, sx + 14, sy + 5, C7)

        # Torso.
        pyxel.rect(sx + 4, sy + 6, 8, 11, body)
        pyxel.rect(sx + 5, sy + 7, 6, 9, bright_body)
        pyxel.line(sx + 8, sy + 6, sx + 8, sy + 16, dark_body)

        # Sash and medals.
        pyxel.line(sx + 4, sy + 7, sx + 11, sy + 15, C8)
        pyxel.line(sx + 5, sy + 7, sx + 12, sy + 14, C14)
        pyxel.pset(sx + 10, sy + 10, C14)
        pyxel.pset(sx + 10, sy + 11, C10)
        pyxel.pset(sx + 6, sy + 11, C7)

        # Belt and buckle.
        pyxel.rect(sx + 4, sy + 16, 8, 2, C4)
        pyxel.pset(sx + 8, sy + 16, C14)

        # Arms: extended toward the facing side right after a shot.
        if self.shoot_anim > 0:
            if self.facing == 1:
                pyxel.line(sx + 12, sy + 9, sx + 16, sy + 9, bright_body)
                pyxel.line(sx + 4, sy + 9, sx + 1, sy + 12, dark_body)
            else:
                pyxel.line(sx + 4, sy + 9, sx, sy + 9, bright_body)
                pyxel.line(sx + 12, sy + 9, sx + 15, sy + 12, dark_body)
        else:
            pyxel.line(sx + 4, sy + 9, sx + 1, sy + 13, dark_body)
            pyxel.line(sx + 12, sy + 9, sx + 15, sy + 13, dark_body)

        # Legs and boots, offset by the walk-cycle step.
        pyxel.rect(sx + 5, sy + 18, 3, 4, dark_body)
        pyxel.rect(sx + 9, sy + 18, 3, 4, dark_body)
        pyxel.line(sx + 6, sy + 22, sx + 5 + step, sy + 25, C5)
        pyxel.line(sx + 10, sy + 22, sx + 11 - step, sy + 25, C5)
        pyxel.rect(sx + 3 + step, sy + 24, 4, 2, C0)
        pyxel.rect(sx + 9 - step, sy + 24, 4, 2, C0)

        # Gun barrel on the facing side.
        if self.facing == 1:
            pyxel.rect(sx + 14, sy + 8, 4, 2, C5)
            pyxel.rect(sx + 17, sy + 8, 3, 1, C7)
        else:
            pyxel.rect(sx - 2, sy + 8, 4, 2, C5)
            pyxel.rect(sx - 3, sy + 8, 3, 1, C7)

        # Muzzle flash right after firing.
        if self.shoot_anim > 0:
            mx = sx + 20 if self.facing == 1 else sx - 3
            my = sy + 9
            pyxel.circ(mx, my, 3, C9)
            pyxel.circ(mx, my, 2, C10)
            pyxel.pset(mx, my, C7)


# ------------------------------------------------------------
# GAME
# ------------------------------------------------------------
class Game:
    def __init__(self):
        """Initialise pyxel, load persistent data, build all game state, then run.

        The final pyxel.run() call hands control to pyxel's update/draw loop, so
        everything the loop needs must be set up before it.
        """
        # Window and audio come first; the rest of the setup may rely on them.
        pyxel.init(SCREEN_W, SCREEN_H, title="Supreme Leader Survival", fps=FPS)
        setup_audio()

        save = load_save()
        self.save_data = save

        self.terrain = Terrain()

        # Mode / cosmetic selection.
        self.mode_names = ["CAMPAIGN", "SURVIVAL", "MULTIPLAYER"]
        self.mode_index = 0
        self.skin_index = 0
        self.country_index = save["selected_country"]

        # Screen-state machine and menu/intro bookkeeping.
        self.state = "intro_country"
        self.intro_timer = 0
        self.intro_flash = 0
        self.intro_bombs = []
        self.dev_mode = False
        self.dev_buffer = ""
        self.tutorial_step = 0
        self.pause_index = 0
        self.story_progress_index = 0

        # Hot-seat multiplayer session.
        self.mp_mode_index = 0
        self.mp_mode_names = ["SURVIVAL ROTATION", "CAMPAIGN ROTATION"]
        self.mp_player_count = 2
        self.mp_current_player = 0
        self.mp_scores = []
        self.mp_segment_progress = 0
        self.mp_campaign_segments_left = 0

        # Run progress and records.
        self.level = 1
        self.distance = 0
        self.total_distance = 0
        self.checkpoint_x = 40
        self.checkpoint_level = 1
        self.survival_frames = 0
        self.best_survival = save["best_survival"]
        self.level_clear_timer = 0

        # Meta-progression read from the save file.
        self.coins_total = save["coins_total"]
        self.unlocked_skins = save["unlocked_skins"]
        self.extra_hp_bonus = save["extra_hp_bonus"]

        self.player = Player(self.terrain, self.skin_index, self.country_index, self.extra_hp_bonus)

        # Live entity collections.
        self.player_bullets = []
        self.enemy_bullets = []
        self.rockets = []
        self.bombs = []
        self.enemies = []
        self.effects = []
        self.particles = []
        self.texts = []
        self.powerups = []
        self.coins = []

        # Camera and screen shake.
        self.cam_x = 0
        self.shake_timer = 0
        self.shake_power = 0
        self.shake_x = 0
        self.shake_y = 0

        # Parallax starfield: every fifth star is bright.
        self.stars = [
            {
                "x": random.randint(0, 4000),
                "y": random.randint(0, 90),
                "color": C7 if n % 5 == 0 else C6,
                "speed": random.choice([0.10, 0.15, 0.22]),
            }
            for n in range(85)
        ]

        pyxel.run(self.update, self.draw)

    def current_country(self):
        """Return the COUNTRIES entry for the currently selected country_index."""
        selected = self.country_index
        return COUNTRIES[selected]

    def is_campaign(self):
        return self.mode_names[self.mode_index] == "CAMPAIGN"

    def is_survival(self):
        return self.mode_names[self.mode_index] == "SURVIVAL"

    def is_multiplayer(self):
        return self.mode_names[self.mode_index] == "MULTIPLAYER"

    def progress_ratio(self):
        """Fraction of LEVEL_DISTANCE covered in the current level, clamped to [0, 1]."""
        covered = self.distance / LEVEL_DISTANCE
        return clamp(covered, 0, 1)

    def difficulty_value(self):
        if self.is_campaign() or (self.is_multiplayer() and self.mp_mode_names[self.mp_mode_index] == "CAMPAIGN ROTATION"):
            return self.level + self.progress_ratio() * 2.2
        return 1 + self.survival_frames / 750.0

    def danger_scale(self):
        if self.is_campaign() or (self.is_multiplayer() and self.mp_mode_names[self.mp_mode_index] == "CAMPAIGN ROTATION"):
            return 1 + (self.level - 1) * 0.24 + self.progress_ratio() * 1.30
        return 1 + self.survival_frames / 650.0

    def reset_run(self):
        """Start a fresh run: rewind progress, rebuild terrain/player, clear the world.

        Purchased upgrades (fire-rate bonus, max-HP bonus, coin magnet) carry
        over from the previous Player instance. Ends in the "tutorial" state.
        """
        self.level = 1
        self.distance = 0
        self.total_distance = 0
        self.checkpoint_x = 40
        self.checkpoint_level = 1
        self.survival_frames = 0

        self.terrain.reset()

        previous = self.player
        fresh = Player(self.terrain, self.skin_index, self.country_index, self.extra_hp_bonus)
        self.player = fresh
        fresh.fire_rate_bonus = previous.fire_rate_bonus
        fresh.max_hp_bonus = max(previous.max_hp_bonus, self.extra_hp_bonus)
        fresh.coin_magnet = previous.coin_magnet
        # Re-place the player now that the max-HP bonus is known.
        fresh.reset(self.checkpoint_x, self.difficulty_value())
        fresh.max_hp_bonus = max(fresh.max_hp_bonus, self.extra_hp_bonus)
        fresh.max_hp = 5 + fresh.max_hp_bonus
        fresh.hp = fresh.max_hp

        world_lists = (
            self.player_bullets, self.enemy_bullets, self.rockets, self.bombs,
            self.enemies, self.effects, self.particles, self.texts,
            self.powerups, self.coins,
        )
        for entries in world_lists:
            entries.clear()

        self.cam_x = 0
        self.state = "tutorial"
        stop_story_music()

    def reset_multiplayer_session(self):
        """Zero every player's score, hand control to player 1 and start a new run."""
        self.mp_scores = [0] * self.mp_player_count
        self.mp_current_player = 0
        self.mp_segment_progress = 0
        self.mp_campaign_segments_left = CAMPAIGN_LEVELS
        self.reset_run()

    def respawn_or_game_over(self):
        self.spawn_explosion(self.player.x + 8, self.player.y + 10, 18, "normal")
        self.add_shake(12, 4)

        if self.is_campaign():
            self.level = self.checkpoint_level
            self.distance = 0
            self.terrain.reset()

            self.player_bullets.clear()
            self.enemy_bullets.clear()
            self.rockets.clear()
            self.bombs.clear()
            self.enemies.clear()
            self.effects.clear()
            self.particles.clear()
            self.texts.clear()
            self.powerups.clear()
            self.coins.clear()

            self.player.reset(self.checkpoint_x, self.difficulty_value())
            self.player.max_hp_bonus = max(self.player.max_hp_bonus, self.extra_hp_bonus)
            self.player.max_hp = 5 + self.player.max_hp_bonus
            self.player.hp = self.player.max_hp

        elif self.is_multiplayer():
            self.mp_scores[self.mp_current_player] += self.player.score
            self.mp_current_player = (self.mp_current_player + 1) % self.mp_player_count

            self.player_bullets.clear()
            self.enemy_bullets.clear()
            self.rockets.clear()
            self.bombs.clear()
            self.enemies.clear()
            self.effects.clear()
            self.particles.clear()
            self.texts.clear()
            self.powerups.clear()
            self.coins.clear()
            self.terrain.reset()

            if self.mp_mode_names[self.mp_mode_index] == "SURVIVAL ROTATION":
                self.player.reset(self.checkpoint_x, self.difficulty_value())
            else:
                self.player.reset(self.checkpoint_x + self.mp_segment_progress, self.difficulty_value())

            self.player.max_hp_bonus = max(self.player.max_hp_bonus, self.extra_hp_bonus)
            self.player.max_hp = 5 + self.player.max_hp_bonus
            self.player.hp = self.player.max_hp
            self.add_text(self.player.x + 10, 24, "PLAYER " + str(self.mp_current_player + 1), C14)

        else:
            self.best_survival = max(self.best_survival, self.survival_frames // 60)
            self.save_data["best_survival"] = self.best_survival
            self.save_data["coins_total"] = self.coins_total
            self.save_data["unlocked_skins"] = self.unlocked_skins
            self.save_data["selected_country"] = self.country_index
            self.save_data["extra_hp_bonus"] = self.extra_hp_bonus
            save_save(self.save_data)
            self.state = "game_over"
            play_story_music()

    def next_level(self):
        """Advance to the next campaign level, or to the victory state after the last one.

        The new checkpoint is set 30 px ahead of the player's current x, and the
        world is cleared before either the story interlude or victory is shown.
        """
        self.level += 1
        self.distance = 0
        self.checkpoint_level = self.level
        self.checkpoint_x = self.player.x + 30

        world_lists = (
            self.player_bullets, self.enemy_bullets, self.rockets, self.bombs,
            self.enemies, self.effects, self.particles, self.texts,
            self.powerups, self.coins,
        )
        for entries in world_lists:
            entries.clear()
        self.terrain.reset()

        if self.level > CAMPAIGN_LEVELS:
            self.state = "victory"
            play_story_music()
            return

        hero = self.player
        hero.reset(self.checkpoint_x, self.difficulty_value())
        hero.max_hp_bonus = max(hero.max_hp_bonus, self.extra_hp_bonus)
        hero.max_hp = 5 + hero.max_hp_bonus
        hero.hp = hero.max_hp

        self.story_progress_index += 1
        self.state = "story_update"
        self.intro_timer = 0
        self.intro_bombs = []
        play_story_music()

    def add_shake(self, timer, power):
        self.shake_timer = max(self.shake_timer, timer)
        self.shake_power = max(self.shake_power, power)

    def update_shake(self):
        if self.shake_timer > 0:
            self.shake_timer -= 1
            self.shake_x = random.randint(-self.shake_power, self.shake_power)
            self.shake_y = random.randint(-self.shake_power, self.shake_power)
        else:
            self.shake_x = 0
            self.shake_y = 0
            self.shake_power = 0

    def spawn_explosion(self, x, y, power=12, kind="normal"):
        """Add an Explosion effect plus (power + 4) fiery particles bursting from (x, y).

        "nuke" explosions throw particles faster (top speed 4.0 instead of 2.8).
        """
        self.effects.append(Explosion(x, y, power, kind))
        top_speed = 4.0 if kind == "nuke" else 2.8
        for _ in range(power + 4):
            heading = random.random() * math.tau
            speed = random.uniform(0.6, top_speed)
            colour = random.choice([C8, C9, C10, C7])
            spark = Particle(
                x, y,
                math.cos(heading) * speed,
                math.sin(heading) * speed - random.random() * 0.5,
                colour, random.randint(10, 24),
                size=random.choice([1, 1, 2]),
                gravity=0.05,
            )
            self.particles.append(spark)

    def spawn_smoke(self, x, y, amount=3):
        """Emit *amount* slowly rising grey puffs scattered a couple of pixels around (x, y)."""
        for _ in range(amount):
            puff = Particle(
                x + random.randint(-2, 2),
                y + random.randint(-2, 2),
                random.uniform(-0.5, 0.5),
                random.uniform(-1.1, -0.3),
                random.choice([C5, C6, C2]),
                random.randint(12, 24),
                size=random.choice([1, 2]),
                gravity=-0.01,  # negative gravity: smoke drifts upward
            )
            self.particles.append(puff)

    def add_text(self, x, y, text, color=C7):
        """Queue a floating FlashText label at world position (x, y)."""
        label = FlashText(x, y, text, color)
        self.texts.append(label)

    # --------------------------------------------------------
    # DEV MODE
    # --------------------------------------------------------
    def update_dev_buffer(self):
        """Record D/E/V key presses and toggle dev mode when "DEV" has been typed.

        Only the last 8 recorded letters are kept. A toggle clears the buffer and
        shows a floating DEV ON / DEV OFF label.
        """
        for key, letter in ((pyxel.KEY_D, "D"), (pyxel.KEY_E, "E"), (pyxel.KEY_V, "V")):
            if not pyxel.btnp(key):
                continue
            self.dev_buffer = (self.dev_buffer + letter)[-8:]

        if "DEV" not in self.dev_buffer:
            return

        self.dev_mode = not self.dev_mode
        self.dev_buffer = ""
        if self.dev_mode:
            label, colour = "DEV ON", C14
        else:
            label, colour = "DEV OFF", C8
        self.add_text(self.player.x + 30, 30, label, colour)

    # --------------------------------------------------------
    # SPAWNING
    # --------------------------------------------------------
    def missile_spawn_chance(self):
        d = self.danger_scale()
        return min(0.006 + d * 0.005, 0.12)

    def drone_spawn_chance(self):
        d = self.danger_scale()
        return min(0.002 + d * 0.0026, 0.04)

    def turret_spawn_chance(self):
        d = self.difficulty_value()
        if d < 2.2:
            return 0
        return min(0.001 + d * 0.0014, 0.012)

    def bomber_spawn_chance(self):
        d = self.difficulty_value()
        if d < 4.8:
            return 0
        return min(0.00035 + d * 0.00022, 0.0024)

    def maybe_spawn_enemies(self):
        """Roll this frame's enemy spawns.

        In campaign-style modes (campaign, or the multiplayer "CAMPAIGN
        ROTATION" mode) levels 3, 6 and 9 summon a boss 240px ahead of the
        player once ``progress_ratio()`` exceeds 0.76. While a boss is alive
        nothing else spawns. Otherwise bombers, missiles, drones and turrets
        each get one independent roll against their ``*_spawn_chance()``.
        """
        boss_alive = any(isinstance(e, BossEnemy) for e in self.enemies)

        campaign_like = self.is_campaign() or (
            self.is_multiplayer() and self.mp_mode_names[self.mp_mode_index] == "CAMPAIGN ROTATION"
        )

        boss_due = (
            campaign_like
            and not boss_alive
            and self.level in (3, 6, 9)
            and self.progress_ratio() > 0.76
        )
        if boss_due:
            self.enemies.append(BossEnemy(self.player.x + 240))
            self.add_text(self.player.x + 80, 36, "BOSS INCOMING", C8)
            sfx(0, 7)
            return

        # a living boss suppresses all regular spawns
        if boss_alive:
            return

        # Bomber: enters from whichever side it is heading away from.
        if random.random() < self.bomber_spawn_chance():
            heading = random.choice([-1, 1])
            if heading == 1:
                start_x = self.cam_x - 180
            else:
                start_x = self.cam_x + SCREEN_W + 40
            altitude = random.randint(22, 46)
            target_x = self.player.x + random.randint(80, 200)
            self.enemies.append(BomberEnemy(start_x, altitude, heading, target_x))
            self.add_text(self.player.x + 120, 26, "STRATEGIC BOMBER", C8)

        # Missile: drops from a random column of the visible screen.
        if random.random() < self.missile_spawn_chance():
            drop_x = self.cam_x + random.randint(0, SCREEN_W - 10)
            speed = 1.3 + self.danger_scale() * 0.42 + random.random() * 0.4
            self.enemies.append(MissileEnemy(drop_x, speed))

        # Drone: enters just off either screen edge aimed at the player.
        if random.random() < self.drone_spawn_chance():
            side = random.choice([-1, 1])
            entry_x = self.cam_x - 30 if side < 0 else self.cam_x + SCREEN_W + 30
            entry_y = random.randint(70, 138)
            aim_x = self.player.x + 8
            aim_y = self.player.y + 8
            drone_speed = 1.1 + self.danger_scale() * 0.18
            self.enemies.append(DroneEnemy(entry_x, entry_y, aim_x, aim_y, drone_speed))

        # Turret: placed ahead of the player, skipped if the spot is a gap.
        if random.random() < self.turret_spawn_chance():
            turret_x = self.player.x + random.randint(160, 320)
            ground = self.terrain.ground_y_at(turret_x + 8, self.difficulty_value())
            if not self.terrain.is_gap_at(turret_x + 8, self.difficulty_value()):
                self.enemies.append(TurretEnemy(turret_x, ground))

    # --------------------------------------------------------
    # INPUT / CAMERA
    # --------------------------------------------------------
    def update_camera(self):
        """Ease ``cam_x`` toward a target that keeps the player framed.

        The player is held one third from the left edge normally and at
        screen centre in survival-style modes. The target never goes below 0,
        and each frame the camera closes 14% of the remaining gap.
        """
        survival_like = self.is_survival() or (
            self.is_multiplayer() and self.mp_mode_names[self.mp_mode_index] == "SURVIVAL ROTATION"
        )
        lead = SCREEN_W // 2 if survival_like else SCREEN_W // 3
        target = self.player.x - lead
        if target < 0:
            target = 0
        self.cam_x += (target - self.cam_x) * 0.14

    def handle_player_fire(self):
        """Turn fire input into projectiles.

        Holding any arrow key fires the current weapon; the returned bullets
        and muzzle effects are queued and a shot sound plays when at least one
        bullet came out. ``K`` launches a rocket and ``J`` drops a bomb, each
        only if the player actually returned one.
        """
        arrows = (pyxel.KEY_LEFT, pyxel.KEY_RIGHT, pyxel.KEY_UP, pyxel.KEY_DOWN)
        if any(pyxel.btn(code) for code in arrows):
            shots, flashes = self.player.fire()
            if shots:
                self.player_bullets.extend(shots)
                self.effects.extend(flashes)
                sfx(0, 0)

        if pyxel.btnp(pyxel.KEY_K):
            new_rocket = self.player.fire_rocket()
            if new_rocket:
                self.rockets.append(new_rocket)

        if pyxel.btnp(pyxel.KEY_J):
            new_bomb = self.player.fire_bomb()
            if new_bomb:
                self.bombs.append(new_bomb)

    # --------------------------------------------------------
    # COLLISIONS
    # --------------------------------------------------------
    def player_hit(self, dmg=1):
        """Deal ``dmg`` damage to the player unless dev mode is active.

        Feedback (screen shake, a small explosion, the hurt sound) plays only
        when ``player.take_damage`` reports that the hit landed. If HP then
        sits at zero or below, ``respawn_or_game_over`` is invoked.
        """
        if self.dev_mode or not self.player.take_damage(dmg):
            return

        self.add_shake(6, 2)
        self.spawn_explosion(self.player.x + 8, self.player.y + 10, 7, "muzzle")
        sfx(1, 1)
        if self.player.hp <= 0:
            self.respawn_or_game_over()

    def kill_enemy(self, enemy):
        """Reward the player for a destroyed ``enemy``.

        Spawns the death explosion (power 26 for bosses, 12 otherwise;
        "nuke" style for heavy bombs), adds the enemy's score with a floating
        "+N" label and drops coins: 12 for a boss, 5 bomber, 3 heavy bomb,
        2 turret, 1 for anything else. A boss also shows "BOSS DOWN" and, in
        campaign-style modes, jumps ``distance`` to LEVEL_DISTANCE. Non-boss
        enemies have a 22% chance to drop a power-up.
        """
        mid_x = enemy.x + enemy.w / 2
        mid_y = enemy.y + enemy.h / 2
        is_boss = isinstance(enemy, BossEnemy)

        blast_kind = "nuke" if isinstance(enemy, HeavyBombEnemy) else "normal"
        self.spawn_explosion(mid_x, mid_y, 26 if is_boss else 12, blast_kind)
        self.player.add_score(enemy.score)
        self.add_text(enemy.x, enemy.y - 10, "+" + str(enemy.score), C10)
        sfx(0, 2)

        # first matching class wins, mirroring an if/elif chain
        coin_count = 1
        for enemy_cls, reward in ((BossEnemy, 12), (BomberEnemy, 5), (HeavyBombEnemy, 3), (TurretEnemy, 2)):
            if isinstance(enemy, enemy_cls):
                coin_count = reward
                break
        self.coins.extend([Coin(mid_x, mid_y, 1) for _ in range(coin_count)])

        if is_boss:
            self.add_text(enemy.x, enemy.y - 26, "BOSS DOWN", C14)
            campaign_like = self.is_campaign() or (
                self.is_multiplayer() and self.mp_mode_names[self.mp_mode_index] == "CAMPAIGN ROTATION"
            )
            if campaign_like:
                self.distance = LEVEL_DISTANCE
            return

        if random.random() < 0.22:
            self.powerups.append(PowerUp(mid_x, mid_y))

    def bullet_hits_enemy(self, bullet, enemy):
        """Resolve one player bullet against one enemy.

        The bullet's hit box is 4x4 centred on its position. On overlap the
        enemy takes ``bullet.dmg``, the bullet is consumed, a puff of smoke is
        spawned and, if the hit was lethal, ``kill_enemy`` runs.

        Returns True when the bullet hit, False otherwise.
        """
        touching = rects_overlap(
            bullet.x - 2, bullet.y - 2, 4, 4,
            enemy.x, enemy.y, enemy.w, enemy.h,
        )
        if not touching:
            return False

        lethal = enemy.hurt(bullet.dmg)
        bullet.alive = False
        self.spawn_smoke(bullet.x, bullet.y, 2)
        if lethal:
            self.kill_enemy(enemy)
        return True

    def handle_collisions(self):
        """Resolve every collision for the current frame, then prune dead objects.

        Passes run in this order:
        player bullets vs enemies, rockets vs enemies, bombs (direct enemy hit,
        otherwise ground impact with crater and splash), enemy bullets vs
        player, enemy bodies vs player, missile / heavy-bomb ground impacts,
        power-up pickups. Finally each object list is filtered down to entries
        whose ``alive`` flag is still set.
        """
        # NOTE(review): the player's hurtbox is sampled once here; if a
        # player_hit() below ends in respawn_or_game_over() and moves the
        # player, later passes in this frame still use the old box.
        px, py, pw, ph = self.player.hurtbox()

        # Player bullets: each bullet is spent on the first live enemy it overlaps.
        for b in self.player_bullets:
            if not b.alive:
                continue
            for e in self.enemies:
                if e.alive and self.bullet_hits_enemy(b, e):
                    break

        # Rockets: detonate within 14px of a live enemy's centre, dealing 4 damage.
        for r in self.rockets:
            if not r.alive:
                continue
            for e in self.enemies:
                if not e.alive:
                    continue
                cx = e.x + e.w / 2
                cy = e.y + e.h / 2
                if dist2(r.x, r.y, cx, cy) < 14 * 14:
                    r.alive = False
                    dead = e.hurt(4)
                    self.spawn_explosion(r.x, r.y, 10, "normal")
                    self.add_shake(5, 2)
                    if dead:
                        self.kill_enemy(e)
                    break

        # Bombs: a direct hit on an enemy (10x10 box) deals 5 damage to that
        # enemy only; no crater and no splash damage in that case.
        for b in self.bombs:
            if not b.alive:
                continue

            hit_enemy = False
            for e in self.enemies:
                if not e.alive:
                    continue
                if rects_overlap(b.x - 5, b.y - 5, 10, 10, e.x, e.y, e.w, e.h):
                    b.alive = False
                    self.spawn_explosion(b.x, b.y, b.power, "nuke")
                    self.add_shake(10, 3)
                    dead = e.hurt(5)
                    if dead:
                        self.kill_enemy(e)
                    hit_enemy = True
                    break

            if hit_enemy:
                continue

            # Otherwise a bomb that reaches solid ground (not over a gap)
            # explodes, calls terrain.add_crater(b.x, 24) and deals 4 damage
            # to every live enemy whose centre is within 40px of the impact.
            gy = self.terrain.ground_y_at(b.x, self.difficulty_value())
            gap = self.terrain.is_gap_at(b.x, self.difficulty_value())
            if not gap and b.y >= gy:
                b.alive = False
                self.spawn_explosion(b.x, gy, b.power, "nuke")
                self.terrain.add_crater(b.x, 24)
                self.add_shake(12, 4)

                for e in self.enemies:
                    if e.alive:
                        cx = e.x + e.w / 2
                        cy = e.y + e.h / 2
                        if dist2(b.x, gy, cx, cy) < 40 * 40:
                            dead = e.hurt(4)
                            if dead:
                                self.kill_enemy(e)

        # Enemy bullets (4x4 box) cost the player 1 HP and are consumed.
        for b in self.enemy_bullets:
            if b.alive and rects_overlap(px, py, pw, ph, b.x - 2, b.y - 2, 4, 4):
                b.alive = False
                self.player_hit(1)

        # Touching any live enemy costs 1 HP; the enemy itself is not removed.
        # NOTE(review): repeat damage while overlapping presumably relies on
        # take_damage()'s invulnerability window — confirm.
        for e in self.enemies:
            if e.alive and rects_overlap(px, py, pw, ph, e.x, e.y, e.w, e.h):
                self.player_hit(1)

        # Missiles and heavy bombs that reach solid ground detonate. They are
        # removed directly (not via kill_enemy), so they award no score,
        # coins or power-ups. Blast radius vs the player: 36px heavy, 16px missile.
        for e in self.enemies:
            if not e.alive:
                continue
            if isinstance(e, (MissileEnemy, HeavyBombEnemy)):
                gy = self.terrain.ground_y_at(e.x + e.w / 2, self.difficulty_value())
                gap = self.terrain.is_gap_at(e.x + e.w / 2, self.difficulty_value())
                if not gap and e.y + e.h >= gy:
                    if isinstance(e, HeavyBombEnemy):
                        self.spawn_explosion(e.x + e.w / 2, gy, 22, "nuke")
                        self.terrain.add_crater(e.x + e.w / 2, 26)
                        self.add_shake(18, 5)
                        if dist2(self.player.x + 8, self.player.y + 10, e.x + e.w / 2, gy) < 36 * 36:
                            self.player_hit(2)
                    else:
                        self.spawn_explosion(e.x + e.w / 2, gy, 11, "normal")
                        self.add_shake(6, 2)
                        if dist2(self.player.x + 8, self.player.y + 10, e.x + e.w / 2, gy) < 16 * 16:
                            self.player_hit(1)
                    e.alive = False

        # Power-up pickups use a fixed 12x12 box at the power-up's position.
        # Timers are in frames (see the HUD, which divides them by 60).
        for p in self.powerups:
            if p.alive and rects_overlap(px, py, pw, ph, p.x, p.y, 12, 12):
                p.alive = False
                sfx(0, 3)

                if p.kind == "rapid":
                    self.player.rapid_timer = 420
                    self.add_text(self.player.x, self.player.y - 10, "RAPID FIRE++", C9)
                elif p.kind == "shield":
                    self.player.invuln = 220
                    self.add_text(self.player.x, self.player.y - 10, "SHIELD", C12)
                elif p.kind == "spread":
                    self.player.spread_timer = 300
                    self.add_text(self.player.x, self.player.y - 10, "SPREAD", C15)
                elif p.kind == "rocket":
                    self.player.rockets += 3
                    self.add_text(self.player.x, self.player.y - 10, "+3 ROCKETS", C10)
                elif p.kind == "bomb":
                    self.player.bombs += 2
                    self.add_text(self.player.x, self.player.y - 10, "+2 BOMBS", C8)
                elif p.kind == "overdrive":
                    self.player.overdrive_timer = 360
                    self.player.rapid_timer = max(self.player.rapid_timer, 180)
                    self.add_text(self.player.x, self.player.y - 10, "OVERDRIVE", C14)
                elif p.kind == "super_shield":
                    self.player.super_shield_timer = 360
                    self.player.invuln = 360
                    self.add_text(self.player.x, self.player.y - 10, "SUPER SHIELD", C7)

        # Drop everything flagged dead this frame. Coins are pruned here even
        # though their pickup is not handled in this method (presumably in
        # Coin.update — verify).
        self.player_bullets = [b for b in self.player_bullets if b.alive]
        self.enemy_bullets = [b for b in self.enemy_bullets if b.alive]
        self.rockets = [r for r in self.rockets if r.alive]
        self.bombs = [b for b in self.bombs if b.alive]
        self.enemies = [e for e in self.enemies if e.alive]
        self.powerups = [p for p in self.powerups if p.alive]
        self.coins = [c for c in self.coins if c.alive]

    # --------------------------------------------------------
    # INTRO / MENU / STORY / TUTORIAL / PAUSE
    # --------------------------------------------------------
    def update_intro_country(self):
        """Country-select screen: LEFT/RIGHT cycle the list, ENTER confirms.

        Confirming stores the choice in the save file, resets the intro
        bombing sequence and switches to the "intro_attack" cutscene with the
        story music playing.
        """
        total = len(COUNTRIES)
        if pyxel.btnp(pyxel.KEY_LEFT):
            self.country_index = (self.country_index - 1) % total
        if pyxel.btnp(pyxel.KEY_RIGHT):
            self.country_index = (self.country_index + 1) % total

        if not pyxel.btnp(pyxel.KEY_RETURN):
            return

        self.save_data["selected_country"] = self.country_index
        save_save(self.save_data)
        self.intro_timer = 0
        self.intro_bombs = []
        self.state = "intro_attack"
        play_story_music()

    def update_intro_attack(self):
        """Advance the opening bombardment cutscene by one frame.

        Until frame 250 a bomb is released every 12 frames. A bomb reaching
        y=118 becomes a nuke explosion with a screen flash, shake and sound.
        Effects and particles are stepped and pruned. After frame 320, ENTER
        moves on to the menu and stops the story music.
        """
        self.intro_timer += 1

        if self.intro_timer < 250 and self.intro_timer % 12 == 0:
            self.intro_bombs.append({
                "x": random.randint(32, SCREEN_W - 32),
                "y": -10,
                "vy": random.uniform(2.0, 3.4),
            })

        still_falling = []
        for bomb in self.intro_bombs:
            bomb["y"] += bomb["vy"]
            if bomb["y"] >= 118:
                self.spawn_explosion(bomb["x"], 118, 18, "nuke")
                self.intro_flash = 10
                self.add_shake(8, 3)
                sfx(0, 2)
                continue
            still_falling.append(bomb)
        self.intro_bombs = still_falling

        if self.intro_flash > 0:
            self.intro_flash -= 1

        for fx in self.effects:
            fx.update()
        for particle in self.particles:
            particle.update()

        self.effects = [fx for fx in self.effects if fx.alive]
        self.particles = [particle for particle in self.particles if particle.alive]

        if self.intro_timer > 320 and pyxel.btnp(pyxel.KEY_RETURN):
            self.state = "menu"
            stop_story_music()

    def update_menu(self):
        """Handle main-menu input.

        LEFT/RIGHT cycle the game mode and UP/DOWN cycle the skin. U buys the
        highlighted skin when it is locked and affordable, saving immediately.
        In multiplayer mode A/D adjust the player count within 2..8 and W or S
        advances the multiplayer sub-mode. ENTER starts a run (or a multiplayer
        session) with the highlighted skin, but only if that skin is unlocked.
        """
        mode_count = len(self.mode_names)
        if pyxel.btnp(pyxel.KEY_LEFT):
            self.mode_index = (self.mode_index - 1) % mode_count
        if pyxel.btnp(pyxel.KEY_RIGHT):
            self.mode_index = (self.mode_index + 1) % mode_count

        skin_count = len(Player.SKINS)
        if pyxel.btnp(pyxel.KEY_UP):
            self.skin_index = (self.skin_index - 1) % skin_count
        if pyxel.btnp(pyxel.KEY_DOWN):
            self.skin_index = (self.skin_index + 1) % skin_count

        if pyxel.btnp(pyxel.KEY_U) and self.skin_index not in self.unlocked_skins:
            price = SKIN_PRICES[self.skin_index]
            if price <= self.coins_total:
                self.coins_total -= price
                self.unlocked_skins.append(self.skin_index)
                self.save_data["coins_total"] = self.coins_total
                self.save_data["unlocked_skins"] = self.unlocked_skins
                save_save(self.save_data)

        if self.is_multiplayer():
            if pyxel.btnp(pyxel.KEY_A):
                self.mp_player_count = max(2, self.mp_player_count - 1)
            if pyxel.btnp(pyxel.KEY_D):
                self.mp_player_count = min(8, self.mp_player_count + 1)
            if pyxel.btnp(pyxel.KEY_W) or pyxel.btnp(pyxel.KEY_S):
                self.mp_mode_index = (self.mp_mode_index + 1) % len(self.mp_mode_names)

        # a locked skin cannot be used to start a run
        if not pyxel.btnp(pyxel.KEY_RETURN) or self.skin_index not in self.unlocked_skins:
            return

        self.player.skin_index = self.skin_index
        self.player.country_index = self.country_index
        if self.is_multiplayer():
            self.reset_multiplayer_session()
        else:
            self.reset_run()

    def update_tutorial(self):
        """Run the interactive tutorial, one required action per step.

        ENTER skips straight to play. Steps: 0 walk right with D past x=70,
        1 press W, 2 press SPACE, 3 fire with an arrow key, 4 drop a bomb (J),
        5 launch a rocket (K), 6 cycle weapon (SHIFT), 7 start playing.
        Steps 3-5 only advance if the player actually produced a projectile.
        Projectiles and effects keep animating throughout.
        """
        if pyxel.btnp(pyxel.KEY_RETURN):
            self.state = "playing"
            return

        step = self.tutorial_step
        if step == 0 and pyxel.btn(pyxel.KEY_D):
            # the tutorial nudges the player itself; player.update is not run here
            self.player.x += 1.5
            if self.player.x > 70:
                self.tutorial_step = 1
        elif step == 1 and pyxel.btnp(pyxel.KEY_W):
            self.tutorial_step = 2
        elif step == 2 and pyxel.btnp(pyxel.KEY_SPACE):
            self.tutorial_step = 3
        elif step == 3:
            arrows = (pyxel.KEY_RIGHT, pyxel.KEY_LEFT, pyxel.KEY_UP, pyxel.KEY_DOWN)
            if any(pyxel.btn(code) for code in arrows):
                shots, flashes = self.player.fire()
                if shots:
                    self.player_bullets.extend(shots)
                    self.effects.extend(flashes)
                    self.tutorial_step = 4
        elif step == 4 and pyxel.btnp(pyxel.KEY_J):
            dropped = self.player.fire_bomb()
            if dropped:
                self.bombs.append(dropped)
                self.tutorial_step = 5
        elif step == 5 and pyxel.btnp(pyxel.KEY_K):
            launched = self.player.fire_rocket()
            if launched:
                self.rockets.append(launched)
                self.tutorial_step = 6
        elif step == 6 and pyxel.btnp(pyxel.KEY_SHIFT):
            self.player.weapon_index = (self.player.weapon_index + 1) % len(self.player.WEAPONS)
            self.tutorial_step = 7
        elif step == 7:
            self.state = "playing"

        for bullet in self.player_bullets:
            bullet.update()
        for fx in self.effects:
            fx.update()
        for rocket in self.rockets:
            # no enemies exist in the tutorial, so rockets get an empty target list
            rocket.update([])
        for bomb in self.bombs:
            bomb.update()

        self.player_bullets = [bullet for bullet in self.player_bullets if bullet.alive]
        self.effects = [fx for fx in self.effects if fx.alive]
        self.rockets = [rocket for rocket in self.rockets if rocket.alive]
        self.bombs = [bomb for bomb in self.bombs if bomb.alive]

    def update_story_update(self):
        """Advance the between-segment story cutscene by one frame.

        Until frame 120 a bomb is released every 18 frames. A bomb reaching
        y=118 becomes a normal explosion with a short flash. Effects and
        particles are stepped and pruned. After frame 180, ENTER resumes play
        and stops the story music.
        """
        self.intro_timer += 1

        if self.intro_timer < 120 and self.intro_timer % 18 == 0:
            self.intro_bombs.append({
                "x": random.randint(50, SCREEN_W - 50),
                "y": -10,
                "vy": random.uniform(1.8, 3.0),
            })

        survivors = []
        for bomb in self.intro_bombs:
            bomb["y"] += bomb["vy"]
            if bomb["y"] >= 118:
                self.spawn_explosion(bomb["x"], 118, 14, "normal")
                self.intro_flash = 5
                continue
            survivors.append(bomb)
        self.intro_bombs = survivors

        if self.intro_flash > 0:
            self.intro_flash -= 1

        for fx in self.effects:
            fx.update()
        for particle in self.particles:
            particle.update()

        self.effects = [fx for fx in self.effects if fx.alive]
        self.particles = [particle for particle in self.particles if particle.alive]

        if self.intro_timer > 180 and pyxel.btnp(pyxel.KEY_RETURN):
            self.state = "playing"
            stop_story_music()

    def update_pause(self):
        """Pause menu with two entries: 0 resumes play, 1 quits to the menu.

        UP/DOWN move the cursor with wrap-around; ENTER activates the entry.
        Quitting to the menu also stops the story music.
        """
        for code, delta in ((pyxel.KEY_UP, -1), (pyxel.KEY_DOWN, 1)):
            if pyxel.btnp(code):
                self.pause_index = (self.pause_index + delta) % 2

        if not pyxel.btnp(pyxel.KEY_RETURN):
            return

        if self.pause_index == 0:
            self.state = "playing"
            return

        self.state = "menu"
        stop_story_music()

    # --------------------------------------------------------
    # UPDATE PLAYING / MAIN UPDATE
    # --------------------------------------------------------
    def update_playing(self):
        """Advance one frame of gameplay.

        Order: dev-code input, player physics, survival timer, fall-out check,
        firing, enemy spawning, per-object updates, collision resolution,
        distance bookkeeping, survival screen clamp, camera and shake, then
        level (campaign) or segment (multiplayer campaign rotation) completion.
        """
        self.update_dev_buffer()
        self.player.update(self.difficulty_value(), self.dev_mode)

        # Survival-style modes count elapsed frames (shown /60 on the HUD).
        if self.is_survival() or (self.is_multiplayer() and self.mp_mode_names[self.mp_mode_index] == "SURVIVAL ROTATION"):
            self.survival_frames += 1

        # Falling 40px below the screen is treated as a lethal hit.
        # NOTE(review): player_hit() does nothing in dev mode, and does nothing
        # while take_damage() refuses the hit; in those cases this branch
        # returns early every frame and the rest of the world stops updating
        # until the hit lands — confirm this is intended.
        if self.player.y > SCREEN_H + 40:
            self.player_hit(99)
            return

        self.handle_player_fire()
        self.maybe_spawn_enemies()

        # Falling enemies leave a smoke trail from their lower edge.
        for e in self.enemies:
            e.update(self)
            if isinstance(e, (MissileEnemy, HeavyBombEnemy)):
                self.spawn_smoke(e.x + e.w / 2, e.y + e.h, 1)

        for b in self.player_bullets:
            b.update()

        for b in self.enemy_bullets:
            b.update()

        # Rockets receive the enemy list (presumably for homing — verify).
        for r in self.rockets:
            r.update(self.enemies)
            self.spawn_smoke(r.x - 2, r.y + 1, 1)

        for b in self.bombs:
            b.update()
            self.spawn_smoke(b.x, b.y, 1)

        for e in self.effects:
            e.update()

        for p in self.particles:
            p.update()

        for t in self.texts:
            t.update()

        for p in self.powerups:
            p.update(self.terrain, self.difficulty_value())

        for c in self.coins:
            c.update(self)

        # Also prunes bullets, rockets, bombs, enemies, power-ups and coins.
        self.handle_collisions()

        self.effects = [e for e in self.effects if e.alive]
        self.particles = [p for p in self.particles if p.alive]
        self.texts = [t for t in self.texts if t.alive]

        # Only forward movement counts toward level / segment progress.
        if self.player.vx > 0:
            self.distance += self.player.vx
            self.total_distance += self.player.vx
            if self.is_multiplayer():
                self.mp_segment_progress += self.player.vx

        # Survival-style modes: keep the player inside the visible screen.
        if self.is_survival() or (self.is_multiplayer() and self.mp_mode_names[self.mp_mode_index] == "SURVIVAL ROTATION"):
            left_limit = self.cam_x + 10
            right_limit = self.cam_x + SCREEN_W - 26
            self.player.x = clamp(self.player.x, left_limit, right_limit)

        self.update_camera()
        self.update_shake()

        if self.is_campaign() and self.distance >= LEVEL_DISTANCE:
            self.next_level()

        # Multiplayer campaign rotation: a segment ends at distance 1500; the
        # current player's score is credited and the turn passes on. After the
        # last segment the run is won, otherwise a story interlude plays.
        # NOTE(review): self.player.score is added as-is each segment; if the
        # shared player object's score is not reset between segments, later
        # players are credited with earlier players' points — verify.
        elif self.is_multiplayer() and self.mp_mode_names[self.mp_mode_index] == "CAMPAIGN ROTATION":
            if self.distance >= 1500:
                self.mp_scores[self.mp_current_player] += self.player.score
                self.level += 1
                self.distance = 0
                self.mp_segment_progress = 0
                self.mp_campaign_segments_left -= 1
                self.mp_current_player = (self.mp_current_player + 1) % self.mp_player_count

                if self.mp_campaign_segments_left <= 0:
                    self.state = "victory"
                    play_story_music()
                else:
                    self.story_progress_index += 1
                    self.state = "story_update"
                    self.intro_timer = 0
                    self.intro_bombs = []
                    play_story_music()

    def update(self):
        """Per-frame entry point: dispatch to the handler for ``self.state``.

        ``P`` toggles between "playing" and "paused". ``ESC`` during play
        returns to the menu. On the "victory" and "game_over" screens ENTER
        saves progress and returns to the menu. Any state not handled here
        falls through to ``update_playing``.
        """
        if pyxel.btnp(pyxel.KEY_P):
            if self.state == "playing":
                self.state = "paused"
                self.pause_index = 0
                return
            if self.state == "paused":
                # The pause key also resumes, so P is a real toggle instead of
                # forcing the player through the pause menu.
                self.state = "playing"
                return

        screen_handlers = {
            "intro_country": self.update_intro_country,
            "intro_attack": self.update_intro_attack,
            "menu": self.update_menu,
            "tutorial": self.update_tutorial,
            "story_update": self.update_story_update,
            "paused": self.update_pause,
        }
        handler = screen_handlers.get(self.state)
        if handler is not None:
            handler()
            return

        if pyxel.btnp(pyxel.KEY_ESCAPE):
            if self.state == "playing":
                self.state = "menu"
                return

        if self.state in ("victory", "game_over"):
            if pyxel.btnp(pyxel.KEY_RETURN):
                self._finish_run_to_menu()
            return

        self.update_playing()

    def _finish_run_to_menu(self):
        """Write long-term progress to the save file and return to the menu."""
        self.save_data["best_survival"] = self.best_survival
        self.save_data["coins_total"] = self.coins_total
        self.save_data["unlocked_skins"] = self.unlocked_skins
        self.save_data["selected_country"] = self.country_index
        self.save_data["extra_hp_bonus"] = self.extra_hp_bonus
        save_save(self.save_data)
        self.state = "menu"
        stop_story_music()

    # --------------------------------------------------------
    # DRAW HELPERS
    # --------------------------------------------------------
    def draw_background(self):
        """Paint the scrolling backdrop behind the world.

        Clears to C1 and draws a fixed two-circle glow in the upper right.
        Then come parallax layers that scroll at growing fractions of
        ``cam_x``: stars (per-star ``speed``), far mountains (0.18), a mid
        skyline with lit windows (0.32) and a near skyline (0.52).
        """
        pyxel.cls(C1)

        pyxel.circ(258, 34, 18, C7)
        pyxel.circ(264, 30, 14, C6)

        # Stars wrap over a strip 40px wider than the screen, shifted 20px left.
        span = SCREEN_W + 40
        for star in self.stars:
            star_x = int((star["x"] - self.cam_x * star["speed"]) % span) - 20
            pyxel.pset(star_x, star["y"], star["color"])

        for i in range(-2, 18):
            left = int(i * 70 - self.cam_x * 0.18)
            peak = 30 + (i * 13) % 22
            pyxel.tri(left, 140, left + 35, 140 - peak, left + 70, 140, C2)

        for i in range(-3, 24):
            left = int(i * 40 - self.cam_x * 0.32)
            height = 18 + ((i * 9) % 26)
            pyxel.rect(left, GROUND_Y - height - 26, 24, height, C2)
            pyxel.rect(left + 4, GROUND_Y - height - 20, 3, 4, C6)
            pyxel.rect(left + 14, GROUND_Y - height - 14, 3, 4, C6)

        for i in range(-3, 35):
            left = int(i * 24 - self.cam_x * 0.52)
            height = 10 + ((i * 7) % 20)
            pyxel.rect(left, GROUND_Y - height - 8, 16, height, C5)

    def draw_bunker_style(self, x, y, country_key):
        """Draw a 38x34 bunker whose base sits at ``y``, at world x ``x``.

        The body, window band and door are the same for every country. A
        small accent (stripe, waving flag, or two lights) varies by
        ``country_key``.
        """
        left = int(x - self.cam_x)
        top = y - 34

        pyxel.rect(left, top, 38, 34, C4)
        pyxel.rect(left + 3, y - 29, 32, 6, C6)
        pyxel.rect(left + 13, y - 16, 10, 16, C0)
        pyxel.line(left + 4, top, left + 31, top, C7)

        if country_key in ("switzerland", "japan", "poland", "austria"):
            stripe = C7 if country_key == "japan" else C8
            pyxel.rect(left + 5, y - 24, 28, 3, stripe)
        elif country_key in ("ukraine", "romania", "france"):
            draw_flag(country_key, left + 10, y - 30, 14, 8, pyxel.frame_count * 0.15)
        elif country_key in ("usa", "germany", "belgium", "russia"):
            stripe = C12 if country_key == "russia" else C10
            pyxel.rect(left + 6, y - 24, 26, 3, stripe)
            pyxel.pset(left + 11, y - 21, C7)
            pyxel.pset(left + 25, y - 21, C7)
        else:
            pyxel.pset(left + 8, y - 23, C10)
            pyxel.pset(left + 28, y - 23, C10)

    def draw_bunkers(self):
        """Draw the checkpoint bunker and, near the goal, the "NUKE BUNKER".

        Only used in campaign-style modes. The goal bunker appears once fewer
        than 360 units remain; the goal is LEVEL_DISTANCE in campaign and 1500
        in multiplayer campaign rotation. Both bunkers are culled when they
        are well off screen.
        """
        campaign_like = self.is_campaign() or (
            self.is_multiplayer() and self.mp_mode_names[self.mp_mode_index] == "CAMPAIGN ROTATION"
        )
        if not campaign_like:
            return

        start_x = self.checkpoint_x - 28
        start_ground = self.terrain.ground_y_at(start_x + 14, self.difficulty_value())
        if -50 <= int(start_x - self.cam_x) <= SCREEN_W + 50:
            self.draw_bunker_style(start_x, start_ground, self.current_country()["key"])

        if self.is_campaign():
            remaining = max(0, LEVEL_DISTANCE - self.distance)
        else:
            remaining = max(0, 1500 - self.distance)
        if remaining >= 360:
            return

        goal_x = self.player.x + remaining
        goal_ground = self.terrain.ground_y_at(goal_x + 18, self.difficulty_value())
        goal_screen_x = int(goal_x - self.cam_x)
        if -60 <= goal_screen_x <= SCREEN_W + 60:
            self.draw_bunker_style(goal_x, goal_ground, self.current_country()["key"])
            pyxel.text(goal_screen_x + 1, goal_ground - 46, "NUKE BUNKER", C7)

    def draw_world(self):
        """Render the world layer: background, terrain, bunkers, actors, effects.

        Screen shake is applied once, through ``pyxel.camera``. Every
        world-space draw receives the unshaken ``cam_x``, so objects that read
        ``self.cam_x`` themselves (background parallax, bunkers) stay aligned
        with objects that take ``cam`` as an argument. Resets the pyxel camera
        to (0, 0) before returning.
        """
        # pyxel.camera() below already shifts everything by shake_x; adding
        # shake_x to cam as well would double the horizontal shake for terrain
        # and actors while the background and bunkers shook only once.
        cam = self.cam_x
        pyxel.camera(self.shake_x, self.shake_y)

        self.draw_background()
        self.terrain.draw(cam, self.difficulty_value())
        self.draw_bunkers()

        for p in self.powerups:
            p.draw(cam)

        for c in self.coins:
            c.draw(cam)

        for e in self.enemies:
            if hasattr(e, "draw"):
                e.draw(cam)

        for b in self.player_bullets:
            b.draw(cam)

        for b in self.enemy_bullets:
            b.draw(cam)

        for r in self.rockets:
            r.draw(cam)

        for b in self.bombs:
            b.draw(cam)

        for e in self.effects:
            e.draw(cam)

        for p in self.particles:
            p.draw(cam)

        for t in self.texts:
            t.draw(cam)

        self.player.draw(cam)

        pyxel.camera(0, 0)

    def draw_hud(self):
        """Draw the heads-up display over the world.

        Top bar: HP pips, coin total, weapon name, rocket and bomb counts,
        score, a DEV marker and the waving country flag. Below it: the combo
        counter and one line per active timed power-up (seconds = frames // 60).
        Mode-specific readout: a boxed survival timer near the top centre, a
        campaign progress bar with level and metres left along the bottom, or
        the multiplayer turn and mode name.
        """
        pyxel.rect(0, 0, SCREEN_W, 24, C0)

        pyxel.text(8, 4, "HP", C7)
        # NOTE(review): pips are 10px apart from x=24, so a max HP above 4
        # runs into the "WPN" label at x=70 — confirm the expected HP range.
        for i in range(self.player.max_hp):
            col = C8 if i < self.player.hp else C5
            pyxel.rect(24 + i * 10, 4, 8, 6, col)

        pyxel.text(8, 14, "$ " + str(self.coins_total), C10)
        pyxel.text(70, 4, "WPN " + self.player.weapon, C7)
        pyxel.text(155, 4, "RKT " + str(self.player.rockets), C10)
        pyxel.text(210, 4, "BMB " + str(self.player.bombs), C8)
        pyxel.text(268, 4, str(self.player.score), C14)

        if self.player.combo > 1:
            pyxel.text(8, 176, "COMBO x" + str(self.player.combo), C10)

        # Active power-up timers stack downward from y=28.
        y = 28
        if self.player.rapid_timer > 0:
            pyxel.text(8, y, "RAPID " + str(self.player.rapid_timer // 60) + "s", C9)
            y += 10
        if self.player.spread_timer > 0:
            pyxel.text(8, y, "SPREAD " + str(self.player.spread_timer // 60) + "s", C15)
            y += 10
        if self.player.overdrive_timer > 0:
            pyxel.text(8, y, "OVERDRIVE " + str(self.player.overdrive_timer // 60) + "s", C14)
            y += 10
        # Short invulnerability windows (<= 90 frames) are not advertised.
        if self.player.invuln > 90:
            pyxel.text(8, y, "SHIELD " + str(self.player.invuln // 60) + "s", C12)

        if self.dev_mode:
            pyxel.text(250, 14, "DEV", C14)

        draw_flag(self.current_country()["key"], 292, 10, 16, 10, pyxel.frame_count * 0.2)

        if self.is_survival() or (self.is_multiplayer() and self.mp_mode_names[self.mp_mode_index] == "SURVIVAL ROTATION"):
            big = str(self.survival_frames // 60)
            pyxel.rect(132, 30, 58, 20, C0)
            pyxel.rectb(132, 30, 58, 20, C7)
            pyxel.text(142, 36, "TIME " + big, C14)
        elif self.is_campaign():
            bx = 100
            by = 190
            bw = 180
            # Clamp so the fill never spills past the frame if distance
            # overshoots LEVEL_DISTANCE before the level transition runs.
            fill = clamp(int(bw * self.distance / LEVEL_DISTANCE), 0, bw)
            pyxel.rect(bx, by, bw, 6, C5)
            pyxel.rect(bx, by, fill, 6, C13)
            pyxel.rectb(bx, by, bw, 6, C7)
            pyxel.text(100, 180, "LEVEL " + str(self.level) + "/" + str(CAMPAIGN_LEVELS), C7)
            pyxel.text(218, 180, str(int(max(0, LEVEL_DISTANCE - self.distance))) + "m", C7)
        else:
            pyxel.text(188, 180, "MP P" + str(self.mp_current_player + 1), C7)
            pyxel.text(188, 190, self.mp_mode_names[self.mp_mode_index][:16], C10)

    # --------------------------------------------------------
    # STORY TEXT
    # --------------------------------------------------------
    def story_lines_intro(self):
        """Return the four opening-cutscene lines for the selected country."""
        country_name = self.current_country()["name"]
        headline = f"{country_name} has been struck without warning."
        follow_up = (
            "Cities burn. Sirens fail. Command breaks.",
            "Only one path remains.",
            "Reach the final bunker and save the nation by pressing the red button.",
        )
        return [headline, *follow_up]

    def story_lines_progress(self):
        """Return the three lines shown in the between-segment story screen.

        The headline is picked by ``story_progress_index`` from four reports;
        any index above 3 reuses the last report. Two fixed lines follow.
        """
        country_name = self.current_country()["name"]
        reports = (
            f"REPORT: {country_name} still resists.",
            "Enemy missiles continue to rain from above.",
            "The war isn't finished.",
            "Push forward. Do not break.",
        )
        headline = reports[min(self.story_progress_index, 3)]
        return [headline, "The people still believe in you.", "Advance to the next front."]

    # --------------------------------------------------------
    # DRAW SCREENS
    # --------------------------------------------------------
    def draw_intro_country(self):
        """Render the country-selection screen on top of the backdrop.

        A framed panel shows a title and the selected country's outline, its
        waving flag and its name, with control hints at the bottom.
        """
        self.draw_background()

        panel = (18, 14, 284, 172)
        pyxel.rect(*panel, C0)
        pyxel.rectb(*panel, C7)

        pyxel.text(86, 26, "CHOOSE YOUR COUNTRY", C14)
        pyxel.text(56, 38, "YOUR FLAG, YOUR MAP, YOUR PRIDE, YOUR WAR", C7)

        country_key = self.current_country()["key"]
        draw_country_outline(country_key, 118, 52, scale=3, color=C7, fill=False)
        draw_flag(country_key, 150, 92, 20, 12, pyxel.frame_count * 0.2)
        pyxel.text(132, 110, self.current_country()["name"], C10)

        for row, hint in ((150, "LEFT/RIGHT = SELECT COUNTRY"), (160, "ENTER = START")):
            pyxel.text(60, row, hint, C6)

    def draw_intro_attack(self):
        """Render the opening bombardment cutscene.

        Draws a burning skyline, the country outline (flashing red-filled
        while ``intro_flash`` is active), the flag on a pole, falling bombs,
        effects and particles. The story lines are revealed one at a time as
        ``intro_timer`` passes 70, 140 and 210, with "PRESS ENTER" after 320.
        """
        pyxel.cls(C0)

        country_key = self.current_country()["key"]

        # Burning city blocks every 18px, with two embers above each roof.
        for bx in range(0, SCREEN_W, 18):
            height = 24 + (bx * 7) % 36
            roof = 132 - height
            pyxel.rect(bx, roof, 12, height, C5)
            pyxel.pset(bx + 3, roof - 2, C9)
            pyxel.pset(bx + 7, roof - 6, C8)

        draw_country_outline(country_key, 88, 22, scale=4, color=C7, fill=False)

        # flicker the red fill every other pair of frames while flashing
        if self.intro_flash > 0 and (pyxel.frame_count // 2) % 2 == 0:
            draw_country_outline(country_key, 88, 22, scale=4, color=C8, fill=True)

        draw_flag(country_key, 146, 40, 28, 18, pyxel.frame_count * 0.15)
        pyxel.line(144, 38, 144, 90, C4)

        for bomb in self.intro_bombs:
            bomb_x, bomb_y = int(bomb["x"]), int(bomb["y"])
            pyxel.circ(bomb_x, bomb_y, 4, C8)
            pyxel.circ(bomb_x, bomb_y, 2, C10)

        # the cutscene has no scrolling camera, hence offset 0
        for fx in self.effects:
            fx.draw(0)
        for particle in self.particles:
            particle.draw(0)

        lines = self.story_lines_intro()
        pyxel.text(38, 144, lines[0], C8)
        reveals = ((70, 154, C7), (140, 164, C14), (210, 174, C7))
        for index, (delay, row, colour) in enumerate(reveals, start=1):
            if self.intro_timer > delay:
                pyxel.text(38, row, lines[index], colour)

        if self.intro_timer > 320:
            pyxel.text(100, 186, "PRESS ENTER", C7)

    def draw_menu(self):
        """Draw the main menu.

        The menu shows the selected country and flag, the mode buttons, the
        skin swatch (with a price when locked), optional multiplayer settings,
        key hints and the coin total. Multiplayer rows are drawn only when
        ``self.is_multiplayer()`` is true.
        """
        self.draw_background()

        # NOTE(review): this outline and the "bombardment" circles below are
        # drawn BEFORE the opaque panel rect at (22, 18, 276, 170). The circles
        # (centres x 190..234, y 92..100, radius <= 10) are fully covered by
        # it, and the outline at (190, 34) presumably is too, so neither is
        # ever visible. Confirm intent: draw them after the panel, or remove.
        draw_country_outline(
            self.current_country()["key"],
            190,
            34,
            scale=2,
            color=C8 if (pyxel.frame_count // 10) % 2 == 0 else C7,
            fill=False
        )

        # bombardment bg
        for i in range(3):
            bx = 190 + i * 22
            pyxel.circ(bx, 92 + (i % 2) * 8, 8 + (pyxel.frame_count // 8) % 3, C8)

        # opaque menu panel with a C7 frame
        pyxel.rect(22, 18, 276, 170, C0)
        pyxel.rectb(22, 18, 276, 170, C7)

        pyxel.text(88, 28, "SUPREME LEADER SURVIVAL", C7)

        pyxel.text(40, 58, "COUNTRY", C7)
        draw_flag(self.current_country()["key"], 96, 56, 18, 12, pyxel.frame_count * 0.15)
        pyxel.text(122, 59, self.current_country()["name"], C7)

        # Mode buttons: 68px wide, 74px apart, labels truncated to 9 chars.
        # The selected mode gets a filled C12 button; the others are outlined.
        # NOTE(review): a 4th mode would start at x=298, past the panel edge.
        pyxel.text(40, 78, "MODE", C7)
        for i, name in enumerate(self.mode_names):
            x = 82 + i * 74
            selected = i == self.mode_index
            if selected:
                pyxel.rect(x - 6, 74, 68, 16, C12)
                pyxel.text(x, 79, name[:9], C0)
            else:
                pyxel.rectb(x - 6, 74, 68, 16, C7)
                pyxel.text(x, 79, name[:9], C7)

        # skin swatch filled with the skin's body colour, name drawn on top
        pyxel.text(40, 100, "SKIN", C7)
        skin_name = Player.SKINS[self.skin_index]["name"]
        pyxel.rect(82, 96, 128, 16, Player.SKINS[self.skin_index]["body"])
        pyxel.text(96, 101, skin_name, C0)

        # locked skins show their price from SKIN_PRICES
        if self.skin_index not in self.unlocked_skins:
            cost = SKIN_PRICES[self.skin_index]
            pyxel.text(216, 101, "LOCK " + str(cost) + "$", C8)

        if self.is_multiplayer():
            pyxel.text(40, 122, "MULTIPLAYER TYPE", C7)
            pyxel.text(160, 122, self.mp_mode_names[self.mp_mode_index][:16], C10)
            pyxel.text(40, 136, "PLAYERS", C7)
            pyxel.text(96, 136, str(self.mp_player_count), C14)
            pyxel.text(126, 136, "A/D CHANGE, W/S MODE", C6)

        # key hints and coin total along the bottom of the panel
        pyxel.text(40, 156, "LEFT/RIGHT MODE", C6)
        pyxel.text(40, 166, "UP/DOWN SKIN, U BUY", C6)
        pyxel.text(180, 156, "ENTER START", C6)
        pyxel.text(180, 166, "COINS " + str(self.coins_total), C10)

    def draw_tutorial(self):
        """Draw the world with the tutorial prompt box on top.

        The box shows the instruction for ``self.tutorial_step``. Steps past the
        end clamp to the last instruction. Long instructions are word-wrapped
        onto 8px-spaced rows so they stay inside the box. The last step (84
        chars, 336px) previously ran off both the box and the 320px screen.
        """
        import textwrap

        self.draw_world()
        self.draw_hud()

        # prompt box spans x 34..285, y 34..107
        pyxel.rect(34, 34, 252, 74, C0)
        pyxel.rectb(34, 34, 252, 74, C7)
        pyxel.text(128, 42, "TUTORIAL", C14)

        steps = [
            "MOVE RIGHT WITH D",
            "PRESS W TO JUMP",
            "PRESS SPACE TO DASH",
            "SHOOT WITH ARROWS",
            "PRESS J FOR BOMB",
            "PRESS SHIFT TO CHANGE WEAPON, K TO SHOOT A ROCKET AND P TO PAUSE THE GAME. LETS GO!",
        ]
        step_text = steps[min(self.tutorial_step, len(steps) - 1)]

        # pyxel glyph cells are 4px wide. Text starts at x=64 and the box's
        # right border is at x=285, so keep a one-glyph margin before it.
        max_chars = (285 - 4 - 64) // 4
        for row, line in enumerate(textwrap.wrap(step_text, max_chars)):
            pyxel.text(64, 64 + row * 8, line, C7)

        pyxel.text(86, 86, "PRESS ENTER TO SKIP", C6)

    def draw_story_update(self):
        """Draw the between-level report: country, flag, skyline, bombs and report text.

        The "PRESS ENTER" prompt appears once ``self.intro_timer`` exceeds 180.
        """
        pyxel.cls(C0)

        key = self.current_country()["key"]
        draw_country_outline(key, 92, 22, scale=4, color=C7, fill=False)
        draw_flag(key, 146, 44, 26, 16, pyxel.frame_count * 0.15)

        # skyline: a 12px-wide building every 20px
        for left in range(0, SCREEN_W, 20):
            tall = 18 + (left * 5) % 30
            pyxel.rect(left, 132 - tall, 12, tall, C5)

        for bomb in self.intro_bombs:
            pyxel.circ(int(bomb["x"]), int(bomb["y"]), 3, C8)

        for effect in self.effects:
            effect.draw(0)
        for particle in self.particles:
            particle.draw(0)

        # headline in C14, then two C7 lines, 12px apart
        report = self.story_lines_progress()
        for row, colour in enumerate((C14, C7, C7)):
            pyxel.text(40, 142 + row * 12, report[row], colour)

        if self.intro_timer > 180:
            pyxel.text(104, 186, "PRESS ENTER", C7)

    def draw_paused(self):
        """Draw the frozen game with a RESUME/LEAVE pause panel on top.

        The option at ``self.pause_index`` gets a C12 highlight bar and C0
        text. The other option is drawn in C7.
        """
        self.draw_world()
        self.draw_hud()

        pyxel.rect(92, 60, 136, 62, C0)
        pyxel.rectb(92, 60, 136, 62, C7)
        pyxel.text(144, 72, "PAUSED", C14)

        for index, label in enumerate(("RESUME", "LEAVE")):
            top = 90 + 12 * index
            highlighted = index == self.pause_index
            if highlighted:
                pyxel.rect(114, top - 2, 90, 10, C12)
            pyxel.text(146, top, label, C0 if highlighted else C7)

    def draw_victory(self):
        """Draw the victory screen: pulsing rings, country outline, flag and final score."""
        pyxel.cls(C0)

        country = self.current_country()
        # ring radius cycles through 18..41, advancing every 4 frames
        radius = 18 + (pyxel.frame_count // 4) % 24
        for extra, colour in ((0, C14), (8, C7)):
            pyxel.circb(160, 92, radius + extra, colour)
        draw_country_outline(country["key"], 110, 42, scale=4, color=C10, fill=False)
        draw_flag(country["key"], 146, 78, 28, 18, pyxel.frame_count * 0.25)

        pyxel.text(120, 126, "VICTORY", C14)
        pyxel.text(54, 140, "YOU SAVED " + country["name"], C7)
        pyxel.text(98, 154, "FINAL SCORE " + str(self.player.score), C10)
        pyxel.text(78, 170, "PRESS ENTER FOR MENU", C7)

    def draw_game_over(self):
        """Draw the game-over summary: survival time, best time and score."""
        pyxel.cls(C0)
        pyxel.text(120, 64, "GAME OVER", C8)
        draw_flag(self.current_country()["key"], 146, 78, 26, 16, pyxel.frame_count * 0.1)

        # frames -> whole seconds (divides by 60, i.e. assumes 60 fps)
        survived_s = self.survival_frames // 60
        summary = (
            (102, 104, "SURVIVED " + str(survived_s) + "s", C7),
            (104, 118, "BEST " + str(self.best_survival) + "s", C10),
            (102, 132, "SCORE " + str(self.player.score), C14),
            (98, 152, "PRESS ENTER FOR MENU", C7),
        )
        for x, y, message, colour in summary:
            pyxel.text(x, y, message, colour)

    def draw(self):
        """Top-level draw entry point, dispatching on ``self.state``.

        Each full-screen state is rendered by the method named
        ``draw_<state>``. Any other state (normal gameplay) draws the world
        followed by the HUD.
        """
        full_screen_states = (
            "intro_country",
            "intro_attack",
            "menu",
            "tutorial",
            "story_update",
            "paused",
            "victory",
            "game_over",
        )
        if self.state in full_screen_states:
            getattr(self, "draw_" + self.state)()
            return

        self.draw_world()
        self.draw_hud()
        
        
        
        
    
    
    
# ============================================================
# PART 3 / CREATIVE PATCH
# Module-level overrides: the functions below replace methods on Game and
# Player, so this section must stay above the final Game() call.
# ============================================================

# ------------------------------------------------------------
# EXTRA COUNTRY FLAVOR
# ------------------------------------------------------------
# City names per country key, used by the patched story and intro text below.
# Those readers take index [0] unclamped and indices [1]/[2] clamped to the
# list's last entry, so every list needs at least one city. Keys missing here
# fall back to ["Capital", "Harbor", "Frontier"].
# NOTE(review): keys presumably mirror COUNTRIES[*]["key"]; verify new entries.
COUNTRY_CITIES = {
    "usa": ["New York", "Washington", "Chicago", "Los Angeles"],
    "mexico": ["Mexico City", "Guadalajara", "Monterrey"],
    "china": ["Beijing", "Shanghai", "Shenzhen"],
    "germany": ["Berlin", "Hamburg", "Munich"],
    "serbia": ["Belgrad", "Novi Sad", "Nis"],
    "iran": ["Teheran", "Tabriz", "Isfahan"],
    "israel": ["Jerusalem", "Tel Aviv", "Haifa"],
    "antarctica": ["McMurdo", "Palmer", "South Pole"],
    "india": ["Delhi", "Mumbai", "Kolkata"],
    "italy": ["Rome", "Milan", "Naples"],
    "russia": ["Moscow", "Saint Petersburg", "Volgograd"],
    "poland": ["Warsaw", "Krakow", "Gdansk"],
    "france": ["Paris", "Lyon", "Marseille"],
    "nigeria": ["Lagos", "Abuja", "Kano"],
    "switzerland": ["Zurich", "Bern", "Geneva"],
    "canada": ["Ottawa", "Toronto", "Montreal"],
    "ukraine": ["Kyiv", "Lviv", "Odesa"],
    "turkiye": ["Ankara", "Istanbul", "Izmir"],
    "south_korea": ["Seoul", "Busan", "Incheon"],
    "north_korea": ["Pyongyang", "Sinuiju", "Wonsan"],
    "saudi": ["Riyadh", "Jeddah", "Mecca"],
    "belgium": ["Brussels", "Antwerp", "Ghent"],
    "austria": ["Vienna", "Graz", "Salzburg"],
    "romania": ["Bucharest", "Cluj", "Timisoara"],
    "japan": ["Tokyo", "Osaka", "Kyoto"],
    "uk": ["London", "Manchester", "Birmingham"],
    "spain": ["Madrid", "Barcelona", "Valencia"],
    "brazil": ["Brasilia", "Rio", "Sao Paulo"],
    "australia": ["Canberra", "Sydney", "Melbourne"],
    "south_africa": ["Pretoria", "Cape Town", "Johannesburg"],
}

# One-line motto per country key, shown in the patched story and intro text.
# Each reader supplies its own generic fallback for keys missing here.
COUNTRY_MOTTOS = {
    "usa": "Stand firm. Hold the line.",
    "mexico": "La bandera no cae.",
    "china": "Advance without breaking.",
    "germany": "Vorwaerts bis zum Bunker.",
    "serbia": "Hold the homeland.",
    "iran": "Defend the last road.",
    "israel": "Stand and survive.",
    "antarctica": "Even ice resists.",
    "india": "Carry the nation forward.",
    "italy": "To the bunker, no retreat.",
    "russia": "Steel against fire.",
    "poland": "Still standing.",
    "france": "Forward under the flag.",
    "nigeria": "Resist and rise.",
    "switzerland": "Stand strong in the mountains.",
    "canada": "Guard the north.",
    "ukraine": "Hold every meter.",
    "turkiye": "Advance under fire.",
    "south_korea": "Push through the storm.",
    "north_korea": "No step back.",
    "saudi": "Protect the homeland.",
    "belgium": "Hold the center.",
    "austria": "Through the passes.",
    "romania": "Protect the frontier.",
    "japan": "Endure and advance.",
    "uk": "Hold the kingdom.",
    "spain": "Push on together.",
    "brazil": "The flag still flies.",
    "australia": "Outlast the fire.",
    "south_africa": "Stand for the nation.",
}

# ------------------------------------------------------------
# PATCHED STORY TEXT
# ------------------------------------------------------------
def _patched_story_lines_intro(self):
    """Return five intro narration lines naming the country and three of its cities.

    Unknown country keys fall back to generic place names and a generic
    motto. A city list shorter than three reuses its last entry.
    """
    country = self.current_country()
    cities = COUNTRY_CITIES.get(country["key"], ["Capital", "Harbor", "Frontier"])
    last = len(cities) - 1
    first_city, second_city, third_city = (cities[min(i, last)] for i in range(3))
    motto = COUNTRY_MOTTOS.get(country["key"], "Carry the flag forward.")
    return [
        f"{country['name']} is under attack.",
        f"{first_city} burns. {second_city} is under missile fire.",
        f"{third_city} has gone dark. Command is collapsing.",
        motto,
        "Reach the final bunker. Save what remains.",
    ]


def _patched_story_lines_progress(self):
    """Return three report lines for the between-level story screen.

    The report cycles through four variants using
    ``self.story_progress_index % 4``. City names and the motto come from
    COUNTRY_CITIES / COUNTRY_MOTTOS, with generic fallbacks for unknown keys.
    """
    country = self.current_country()
    name = country["name"]
    cities = COUNTRY_CITIES.get(country["key"], ["Capital", "Harbor", "Frontier"])
    motto = COUNTRY_MOTTOS.get(country["key"], "Push forward.")

    def city(i):
        # clamp so a short city list reuses its last entry
        return cities[min(i, len(cities) - 1)]

    phase = self.story_progress_index % 4
    if phase == 0:
        return [
            f"FIELD UPDATE: {name} still resists.",
            f"{city(0)} reports heavy damage, but the flag still stands.",
            motto,
        ]
    if phase == 1:
        return [
            f"FRONTLINE REPORT: enemy pressure rising near {city(1)}.",
            "The bunker route remains open for now.",
            "Advance immediately.",
        ]
    if phase == 2:
        return [
            f"CIVIL DEFENSE MESSAGE: {city(2)} is evacuating.",
            "The nation watches your progress.",
            motto,
        ]
    return [
        f"COMMAND REPORT: {name} can still be saved.",
        "Missile rain continues. The final road is close.",
        "Do not stop now.",
    ]


# install the patched narration on Game
Game.story_lines_intro = _patched_story_lines_intro
Game.story_lines_progress = _patched_story_lines_progress

# ------------------------------------------------------------
# PATCHED PAUSE UPDATE
# ------------------------------------------------------------
def _patched_update_pause(self):
    """Handle input on the pause screen.

    UP/DOWN move the cursor between RESUME (0) and LEAVE (1) with wrap-around.
    P resumes immediately. ENTER activates the highlighted option: LEAVE
    switches to the menu and stops the story music.
    """
    option_count = 2
    if pyxel.btnp(pyxel.KEY_UP):
        self.pause_index = (self.pause_index - 1) % option_count
    if pyxel.btnp(pyxel.KEY_DOWN):
        self.pause_index = (self.pause_index + 1) % option_count

    if pyxel.btnp(pyxel.KEY_P):
        self.state = "playing"
    elif pyxel.btnp(pyxel.KEY_RETURN):
        leaving = self.pause_index != 0
        self.state = "menu" if leaving else "playing"
        if leaving:
            stop_story_music()

Game.update_pause = _patched_update_pause

# ------------------------------------------------------------
# PATCHED PLAYER DRAW
# stronger dash flag trail + more visible commander silhouette
# ------------------------------------------------------------
# Keep a handle on the stock Player.draw before it is replaced below.
# NOTE(review): nothing in this section calls _old_player_draw.
_old_player_draw = Player.draw

def _patched_player_draw(self, cam_x):
    """Draw the player sprite; installed below as Player.draw.

    Args:
        cam_x: camera x offset. The sprite's top-left is drawn at screen
            position (int(self.x - cam_x), int(self.y)).

    Draw order matters because later calls paint over earlier ones: shield
    rings, dash trail, cape, carried flag, black backing, helmet, torso,
    arms, legs, gun, muzzle flash. Vertically the sprite spans sy - 3
    (helmet crest) to sy + 25 (boots).
    """
    sx = int(self.x - cam_x)
    sy = int(self.y)

    # hurt blink: skip the whole sprite on every other 2-frame window
    if self.hurt_flash > 0 and (pyxel.frame_count // 2) % 2 == 0:
        return

    body = self.skin["body"]
    cape = self.skin["cape"]
    visor = self.skin["visor"]

    # NOTE(review): pyxel palette indices are not ordered by brightness, so
    # body - 1 / body + 1 select neighbouring palette entries rather than
    # true darker/lighter shades of the body colour. Confirm this is intended.
    dark_body = max(1, body - 1)
    bright_body = min(15, body + 1)

    # leg offset alternates +1/-1 every 5 walk frames while moving on ground
    step = 0
    if self.on_ground and abs(self.vx) > 0.1:
        step = 1 if (self.walk_frame // 5) % 2 == 0 else -1

    # invulnerability bubble, blinking on 4-frame windows. Outer ring is C7
    # while the super shield is active, C12 otherwise.
    if self.invuln > 0 and (pyxel.frame_count // 4) % 2 == 0:
        col = C7 if self.super_shield_timer > 0 else C12
        pyxel.circb(sx + 8, sy + 11, 13, col)
        pyxel.circb(sx + 8, sy + 11, 11, C6)

    # much stronger flag dash trail
    # Five ghost copies 6px apart, placed opposite the facing direction. Each
    # has a pole, a small waving flag and a C6 body block.
    if self.dash_timer > 0:
        for i in range(1, 6):
            ox = sx - self.facing * i * 6
            trail_y = sy + (i % 2)
            pole_x = ox + 20 if self.facing == 1 else ox - 8
            pyxel.line(pole_x, trail_y + 3, pole_x, trail_y + 19, C5)
            draw_flag(
                self.country["key"],
                pole_x + (1 if self.facing == 1 else -14),
                trail_y + 4,
                12,
                7,
                pyxel.frame_count * 0.45 + i * 0.7
            )
            pyxel.rect(ox + 4, trail_y + 7, 8, 10, C6)

    # cape: two triangles whose tips swap height every 4 frames
    cape_wave = 1 if (pyxel.frame_count // 4) % 2 == 0 else -1
    pyxel.tri(sx + 4, sy + 8, sx - 4, sy + 21 + cape_wave, sx + 5, sy + 20, cape)
    pyxel.tri(sx + 12, sy + 8, sx + 20, sy + 21 - cape_wave, sx + 11, sy + 20, cape)
    pyxel.line(sx + 6, sy + 8, sx + 3, sy + 20, dark_body)
    pyxel.line(sx + 10, sy + 8, sx + 13, sy + 20, dark_body)

    # Carried flag. The pole sits 20px right of the sprite's left edge when
    # facing right, 8px left of it otherwise; the flag hangs on the outer side.
    pole_x = sx + 20 if self.facing == 1 else sx - 8
    pyxel.line(pole_x, sy + 1, pole_x, sy + 23, C4)
    if self.facing == 1:
        draw_flag(self.country["key"], pole_x + 1, sy + 3, 15, 9, pyxel.frame_count * 0.28)
    else:
        draw_flag(self.country["key"], pole_x - 16, sy + 3, 15, 9, pyxel.frame_count * 0.28)

    # shadow
    # Black backing silhouette (sx+1..sx+14, sy-1..sy+18). It also paints
    # over most of the two dark_body cape lines drawn just above.
    pyxel.rect(sx + 2, sy - 1, 12, 8, C0)
    pyxel.rect(sx + 1, sy + 5, 14, 14, C0)

    # helmet (C11), visor brim at sy - 2, two C14 crest pixels at sy - 3
    pyxel.rect(sx + 4, sy, 8, 5, C11)
    pyxel.rect(sx + 3, sy + 1, 10, 3, C11)
    pyxel.rect(sx + 3, sy - 2, 10, 2, visor)
    pyxel.pset(sx + 7, sy - 3, C14)
    pyxel.pset(sx + 8, sy - 3, C14)

    # eye slit shifted toward the facing direction, with a C7 glint
    if self.facing == 1:
        pyxel.rect(sx + 7, sy + 2, 4, 2, visor)
        pyxel.pset(sx + 10, sy + 2, C7)
    else:
        pyxel.rect(sx + 5, sy + 2, 4, 2, visor)
        pyxel.pset(sx + 5, sy + 2, C7)

    # shoulder pads with C7 top edges
    pyxel.rect(sx + 1, sy + 6, 4, 4, bright_body)
    pyxel.rect(sx + 11, sy + 6, 4, 4, bright_body)
    pyxel.line(sx + 1, sy + 6, sx + 4, sy + 5, C7)
    pyxel.line(sx + 11, sy + 6, sx + 14, sy + 5, C7)

    # torso with a centre seam
    pyxel.rect(sx + 4, sy + 6, 8, 11, body)
    pyxel.rect(sx + 5, sy + 7, 6, 9, bright_body)
    pyxel.line(sx + 8, sy + 6, sx + 8, sy + 16, dark_body)

    # diagonal C8/C14 stripe across the torso plus a few detail pixels
    pyxel.line(sx + 4, sy + 7, sx + 11, sy + 15, C8)
    pyxel.line(sx + 5, sy + 7, sx + 12, sy + 14, C14)
    pyxel.pset(sx + 10, sy + 10, C14)
    pyxel.pset(sx + 10, sy + 11, C10)
    pyxel.pset(sx + 6, sy + 11, C7)

    # belt with a C14 buckle pixel
    pyxel.rect(sx + 4, sy + 16, 8, 2, C4)
    pyxel.pset(sx + 8, sy + 16, C14)

    # arms: the front arm points straight out while shooting, else both hang down
    if self.shoot_anim > 0:
        if self.facing == 1:
            pyxel.line(sx + 12, sy + 9, sx + 16, sy + 9, bright_body)
            pyxel.line(sx + 4, sy + 9, sx + 1, sy + 12, dark_body)
        else:
            pyxel.line(sx + 4, sy + 9, sx, sy + 9, bright_body)
            pyxel.line(sx + 12, sy + 9, sx + 15, sy + 12, dark_body)
    else:
        pyxel.line(sx + 4, sy + 9, sx + 1, sy + 13, dark_body)
        pyxel.line(sx + 12, sy + 9, sx + 15, sy + 13, dark_body)

    # Legs and boots; `step` swings them in opposite directions. The boots
    # reach sy + 25, i.e. 2px below a PLAYER_H (24) tall box.
    pyxel.rect(sx + 5, sy + 18, 3, 4, dark_body)
    pyxel.rect(sx + 9, sy + 18, 3, 4, dark_body)
    pyxel.line(sx + 6, sy + 22, sx + 5 + step, sy + 25, C5)
    pyxel.line(sx + 10, sy + 22, sx + 11 - step, sy + 25, C5)
    pyxel.rect(sx + 3 + step, sy + 24, 4, 2, C0)
    pyxel.rect(sx + 9 - step, sy + 24, 4, 2, C0)

    # gun: C5 body with a C7 tip on the facing side
    if self.facing == 1:
        pyxel.rect(sx + 14, sy + 8, 4, 2, C5)
        pyxel.rect(sx + 17, sy + 8, 3, 1, C7)
    else:
        pyxel.rect(sx - 2, sy + 8, 4, 2, C5)
        pyxel.rect(sx - 3, sy + 8, 3, 1, C7)

    # muzzle flash at the gun tip while the shoot animation runs
    if self.shoot_anim > 0:
        mx = sx + 20 if self.facing == 1 else sx - 3
        my = sy + 9
        pyxel.circ(mx, my, 3, C9)
        pyxel.circ(mx, my, 2, C10)
        pyxel.pset(mx, my, C7)

Player.draw = _patched_player_draw

# ------------------------------------------------------------
# PATCHED HUD DRAW
# adds a multiplayer scoreboard panel on top of the stock HUD
# ------------------------------------------------------------
_old_draw_hud = Game.draw_hud


def _patched_draw_hud(self):
    """Draw the original HUD, then a multiplayer scoreboard when active.

    The scoreboard is a bordered panel at (212, 28), 102px wide. It has a
    "MULTIPLAYER" title and one ``P<n>: <score>`` row per player
    (``self.mp_player_count`` rows, 8px apart). The row for
    ``self.mp_current_player`` is drawn in C10, the others in C7. Players
    with no entry in ``self.mp_scores`` show a score of 0.
    """
    _old_draw_hud(self)

    if not self.is_multiplayer():
        return

    panel_x, panel_y, panel_w = 212, 28, 102
    row_count = self.mp_player_count
    # Rows start 14px below the panel top and are 8px apart. pyxel glyph
    # cells are 6px tall, so the panel needs 14 + 8 * count px for the last
    # row to clear the bottom border. The previous height of 10 + 8 * count
    # put the border straight through the last player's score.
    panel_h = 14 + row_count * 8
    pyxel.rect(panel_x, panel_y, panel_w, panel_h, C0)
    pyxel.rectb(panel_x, panel_y, panel_w, panel_h, C7)
    pyxel.text(panel_x + 6, panel_y + 4, "MULTIPLAYER", C14)

    for i in range(row_count):
        y = panel_y + 14 + i * 8
        col = C10 if i == self.mp_current_player else C7
        score = self.mp_scores[i] if i < len(self.mp_scores) else 0
        pyxel.text(panel_x + 6, y, f"P{i+1}: {score}", col)


Game.draw_hud = _patched_draw_hud

# ------------------------------------------------------------
# PATCHED DRAW INTRO ATTACK
# longer / more cinematic city attack
# ------------------------------------------------------------
def _patched_draw_intro_attack(self):
    """Draw the longer city-attack cinematic for the selected country.

    Narration names the country's cities. Later lines appear as
    ``self.intro_timer`` passes 60, 130 and 200 frames, and the
    "PRESS ENTER" prompt appears after 320 frames.
    """
    pyxel.cls(C0)
    key = self.current_country()["key"]
    cities = COUNTRY_CITIES.get(key, ["Capital", "Harbor", "Frontier"])
    last_city = len(cities) - 1

    # dense skyline: a 9px tower every 14px with two ember pixels above it
    for tower_x in range(0, SCREEN_W, 14):
        tower_h = 18 + (tower_x * 9) % 42
        roof = 138 - tower_h
        pyxel.rect(tower_x, roof, 9, tower_h, C5)
        pyxel.pset(tower_x + 2, roof - 2, C8)
        pyxel.pset(tower_x + 5, roof - 8, C9)

    # country outline, with a blinking filled version while intro_flash runs
    draw_country_outline(key, 88, 18, scale=4, color=C7, fill=False)
    if self.intro_flash > 0 and (pyxel.frame_count // 2) % 2 == 0:
        draw_country_outline(key, 88, 18, scale=4, color=C8, fill=True)

    draw_flag(key, 146, 38, 28, 18, pyxel.frame_count * 0.15)
    pyxel.line(144, 36, 144, 90, C4)  # flag pole

    for bomb in self.intro_bombs:
        cx, cy = int(bomb["x"]), int(bomb["y"])
        pyxel.circ(cx, cy, 4, C8)
        pyxel.circ(cx, cy, 2, C10)

    for effect in self.effects:
        effect.draw(0)
    for particle in self.particles:
        particle.draw(0)

    timer = self.intro_timer
    pyxel.text(26, 144, f"{cities[0]} is burning.", C8)
    if timer > 60:
        pyxel.text(26, 154, f"{cities[min(1, last_city)]} is under missile fire.", C7)
    if timer > 130:
        pyxel.text(26, 164, f"{cities[min(2, last_city)]} has gone dark.", C7)
    if timer > 200:
        pyxel.text(26, 174, COUNTRY_MOTTOS.get(key, "Carry the flag forward."), C14)
    if timer > 320:
        pyxel.text(100, 186, "PRESS ENTER", C7)

Game.draw_intro_attack = _patched_draw_intro_attack

# ------------------------------------------------------------
# PATCHED STORY UPDATE DRAW
# ------------------------------------------------------------
def _patched_draw_story_update(self):
    """Draw the between-level story screen with the report text at x=28.

    The three lines from ``self.story_lines_progress()`` are spaced 14px
    apart. The "PRESS ENTER" prompt appears once ``self.intro_timer``
    exceeds 180.
    """
    pyxel.cls(C0)
    key = self.current_country()["key"]

    draw_country_outline(key, 92, 22, scale=4, color=C7, fill=False)
    draw_flag(key, 146, 44, 26, 16, pyxel.frame_count * 0.15)

    # skyline: a 12px-wide building every 20px
    for left in range(0, SCREEN_W, 20):
        tall = 18 + (left * 5) % 30
        pyxel.rect(left, 132 - tall, 12, tall, C5)

    for bomb in self.intro_bombs:
        pyxel.circ(int(bomb["x"]), int(bomb["y"]), 3, C8)

    for effect in self.effects:
        effect.draw(0)
    for particle in self.particles:
        particle.draw(0)

    # headline in C14, then two C7 lines
    report = self.story_lines_progress()
    for row, colour in enumerate((C14, C7, C7)):
        pyxel.text(28, 140 + 14 * row, report[row], colour)

    if self.intro_timer > 180:
        pyxel.text(104, 186, "PRESS ENTER", C7)

Game.draw_story_update = _patched_draw_story_update

# ------------------------------------------------------------
# PATCHED VICTORY DRAW
# ------------------------------------------------------------
def _patched_draw_victory(self):
    """Draw the victory screen: three pulsing rings, country outline, flag and final score."""
    pyxel.cls(C0)
    country = self.current_country()
    # ring radius cycles through 18..41, advancing every 4 frames
    radius = 18 + (pyxel.frame_count // 4) % 24

    for offset, colour in ((0, C14), (8, C7), (16, C10)):
        pyxel.circb(160, 92, radius + offset, colour)

    draw_country_outline(country["key"], 108, 38, scale=4, color=C10, fill=False)
    draw_flag(country["key"], 146, 76, 28, 18, pyxel.frame_count * 0.25)

    pyxel.text(118, 126, "VICTORY", C14)
    pyxel.text(44, 140, "THE FLAG STILL FLIES OVER " + country["name"], C7)
    pyxel.text(98, 154, "FINAL SCORE " + str(self.player.score), C10)
    pyxel.text(78, 170, "PRESS ENTER FOR MENU", C7)

Game.draw_victory = _patched_draw_victory


# ============================================================
# END PATCH
# ============================================================


# Entry point: constructing Game starts the app. This runs after every patch
# above has been installed on Game and Player.
# NOTE(review): presumably Game.__init__ calls pyxel.init/pyxel.run (not
# visible in this section) — confirm.
Game()