"""Micro Racers — a two-player, top-down pixel racing game built with Pyxel.

The track is a tilemap (res5.pyxres). Terrain, checkpoints and hazards are read
straight from the rendered pixels, so the track is drawn every frame in update()
(for those reads) and again in draw() (for display).

Structure:
    constants            tuning values, palette colours, sound ids
    sound definitions    define_sounds()
    math + track helpers geometry helpers, plus terrain helpers that read the framebuffer
    ExhaustParticle      smoke puff behind a moving car
    Spectator            cheering crowd member (can be run over)
    Car                  one player's car: physics, checkpoints, drawing
    Game                 top-level state machine, HUD and the Pyxel loop
"""

import pyxel
# Keep this star-import ABOVE the math import: Pyxel 2.x also exports
# sin/cos/sqrt (sin/cos take degrees), and the math versions must win.
from pyxel import *
from collections import deque
from math import sin, cos, pi, sqrt
from random import seed, randint, random, choice
from titlescreen import draw_title_art


# ===========================================================================
# Constants
# ===========================================================================
SCREEN_W = 1024
SCREEN_H = 512
HUD_H = 20                    # top strip for the HUD; terrain reads clamp y to >= HUD_H

# Driving physics. Values are applied once per frame (the game runs at 30 fps):
# speeds are pixels/frame, TURN_SPEED is radians/frame, and the FRICTION_*
# factors multiply the speed every frame on that terrain.
TURN_SPEED = 0.06
ACCEL = 0.065
BRAKE = 0.08
MAX_SPEED_TRACK = 6.2
MAX_SPEED_CURB = 3.4
MAX_SPEED_GRASS = 1.8
FRICTION_ROAD = 0.9895
FRICTION_CURB = 0.9838
FRICTION_GRASS = 0.9801
START_ANGLE = pi / 2          # cars start pointing right (+x)

# Car-vs-car collision: when two car centres are closer than CAR_RADIUS, each
# car is pushed apart by half the overlap plus 1 + CAR_REPEL_FORCE pixels.
CAR_RADIUS = 12
CAR_REPEL_FORCE = 3.0

# Race rules
LAPS_TO_WIN = 3
NUM_CHECKPOINTS = 3

# Track tilemap palette colours (Pyxel palette indices), read back with pget().
# Several indices are shared between meanings and are disambiguated by looking
# at neighbouring pixels (see is_white_curb / is_deep_water).
COL_GRASS = 11
COL_CURB_RED = 8
COL_CURB_WHITE = 7
COL_ROAD_A = 5
COL_ROAD_B = 13
COL_LANE = 7
COL_TREE = 3
COL_LAKE_DEEP = 5             # deep water (cars sink); same index as COL_ROAD_A
COL_LAKE_ICE = 6             # light blue ice (driveable; classified as "road")
COL_CHECKPOINT = 14
COL_CHECKPOINT_EDGE = 15      # only counts when it borders a COL_CHECKPOINT pixel

# Skidmarks + exhaust
SKID_TURN_THRESHOLD = 0.025   # radians turned in one frame to count as sliding
SKID_MIN_SPEED = 2.0          # pixels/frame needed before skids can start
SKID_SUSTAIN_FRAMES = 8       # consecutive sliding frames before marks are laid
MAX_SKIDMARKS = 4000          # skid points kept (oldest dropped first)
MAX_EXHAUST_PARTICLES = 60    # per car

# Lake death (frames): car shrinks while sinking, then only bubbles remain
SINK_DURATION = 120
SINK_BUBBLE_END = 180

SPECTATOR_HIT_RADIUS = 14     # car-centre distance (px) that runs a spectator over
# (centre x, centre y, radius) areas used for the first 60% of the crowd.
SPECTATOR_HOTSPOTS = [
    (525, 80, 80), (525, 110, 70), (155, 430, 80), (1005, 470, 70),
    (475, 280, 80), (300, 200, 60), (800, 300, 60), (700, 100, 60),
]


# ===========================================================================
# Sound
# ===========================================================================
# Sound bank slots filled by define_sounds(). Slot 3 is not defined in this
# file (presumably unused or supplied by res5.pyxres -- TODO confirm).
# Channels used by the code: 0/1 = per-car engine loops, 2 = crash/splash,
# 3 = countdown, checkpoint, lap and victory jingles.
SND_ENGINE_IDLE = 0
SND_ENGINE_LOW = 1
SND_ENGINE_MID = 2
SND_SPLASH = 4
SND_CHECKPOINT = 5
SND_CRASH = 6
SND_COUNTDOWN_BEEP = 7
SND_COUNTDOWN_GO = 8
SND_VICTORY = 9
SND_LAP = 10


def define_sounds():
    """Program every game sound into Pyxel's sound bank.

    Each row is (slot, notes, tones, volumes, effects, speed), passed
    unchanged to ``sound(slot).set(...)`` in the order listed.
    """
    table = (
        (SND_ENGINE_IDLE, "c1 c1 d1 c1 c1 d1 c1 c1", "p", "2 2 2 2 2 2 2 2", "n", 5),
        (SND_ENGINE_LOW, "e1 f1 e1 f1 e1 f1 e1 f1", "s", "3 3 3 3 3 3 3 3", "n", 4),
        (SND_ENGINE_MID, "c2 d2 c2 d2 c2 d2 c2 d2", "s", "4 4 5 4 4 5 4 4", "n", 3),
        (SND_SPLASH,
         "c3 g2 e2 c2 g1 e1 c1 c0", "n n t t t t t t",
         "7 6 6 5 5 4 3 2", "n n n s s f f f", 3),
        (SND_CHECKPOINT, "e3 g3", "s", "5 5", "n", 6),
        (SND_CRASH,
         "g2 e2 c2 g1 e1 c1 c0 c0", "s s s s s s t t",
         "7 7 6 6 5 4 4 2", "n s s s s s f f", 2),
        (SND_COUNTDOWN_BEEP,
         "e3 r e3 r r r r r", "s s p s s s s s",
         "7 0 5 0 0 0 0 0", "n n f n n n n n", 3),
        (SND_COUNTDOWN_GO,
         "c2 g2 c3 e3 g3 c4 r c4", "s s s s s s s p",
         "5 5 6 6 7 7 0 6", "n n n n n v n f", 3),
        (SND_VICTORY,
         "c3 r e3 r g3 g3 c4 c4 c4 r e4 e4 g4 g4 g4 r",
         "s s s s s s s s s s s s s s s s",
         "5 0 6 0 6 7 7 7 7 0 7 7 7 7 6 0",
         "n n n n n n v v n n v v v v f n", 6),
        (SND_LAP,
         "g3 r c4 r e4 r g4 r", "p p p p p p p p",
         "6 0 6 0 7 0 7 0", "n n n n n n v n", 3),
    )
    for slot, notes, tones, volumes, effects, speed in table:
        sound(slot).set(notes, tones, volumes, effects, speed)


# ===========================================================================
# Math helpers
# ===========================================================================
def clamp(value, lo, hi):
    """Limit *value* to the range [lo, hi]; *lo* wins if the bounds cross."""
    capped = hi if hi < value else value
    return capped if capped > lo else lo


def ddist(x1, y1, x2, y2):
    """Euclidean distance between the points (x1, y1) and (x2, y2)."""
    dx = x1 - x2
    dy = y1 - y2
    return sqrt(dx ** 2 + dy ** 2)


def forward_vec(angle):
    """Unit heading for *angle*: 0 points up the screen (-y), pi/2 points right."""
    dx = sin(angle)
    dy = -cos(angle)
    return dx, dy


def side_vec(angle):
    """Unit vector 90 degrees clockwise of forward_vec(angle): (-fy, fx)."""
    # forward is (sin a, -cos a), so (-fy, fx) simplifies to (cos a, sin a).
    return cos(angle), sin(angle)


def angle_diff(a, b):
    """Signed difference ``a - b`` wrapped into (-pi, pi].

    A single modulo replaces the old repeated +/-2*pi loops. Their cost grew
    with the raw gap between the angles, and they never terminated for inf.
    Car angles are never normalised, so the raw values keep growing as
    players turn.
    """
    d = (a - b) % (2 * pi)        # now in [0, 2*pi)
    if d > pi:
        d -= 2 * pi               # fold the upper half onto (-pi, 0)
    return d


# ===========================================================================
# Track / terrain — read straight from the rendered framebuffer
# ===========================================================================
def render_track():
    """Draw the whole track tilemap onto the screen.

    The screen is cleared first because bltm() is called with colkey 0, so any
    black tilemap pixels are skipped (transparent). Without the clear, whatever
    was drawn there earlier (last frame's cars and skidmarks, the title art)
    would stay visible. It would also be read back by the pget()-based
    terrain and checkpoint checks.
    """
    cls(0)
    bltm(0, 0, 0, 0, 0, SCREEN_W, SCREEN_H, 0)


def color_at(x, y):
    """Palette index on screen at (x, y), with the point clamped onto the
    playfield (below the HUD strip, inside the screen)."""
    px = clamp(int(x), 0, SCREEN_W - 1)
    py = clamp(int(y), HUD_H, SCREEN_H - 1)
    return pget(px, py)


def is_white_curb(x, y):
    """True if a red curb pixel lies 1 or 2px away (orthogonally) from (x, y).

    White is also used for lane markings, so a white pixel only counts as curb
    when it sits next to red.
    """
    probes = ((1, 0), (-1, 0), (0, 1), (0, -1),
              (2, 0), (-2, 0), (0, 2), (0, -2))
    for dx, dy in probes:
        if color_at(x + dx, y + dy) == COL_CURB_RED:
            return True
    return False


def is_tree_pixel(x, y):
    """True if the pixel at (x, y) is painted in the tree colour."""
    pixel = color_at(x, y)
    return pixel == COL_TREE


def is_deep_water(x, y):
    """True for deep lake water at (x, y).

    Deep water shares its palette index with COL_ROAD_A. The pixel therefore
    only counts when at least two of the four pixels 3px away are water, ice
    or grass.
    """
    if color_at(x, y) == COL_LAKE_DEEP:
        lakeside = (COL_LAKE_DEEP, COL_LAKE_ICE, COL_GRASS)
        hits = 0
        for dx, dy in ((3, 0), (-3, 0), (0, 3), (0, -3)):
            if color_at(x + dx, y + dy) in lakeside:
                hits += 1
        return hits >= 2
    return False


def terrain_at(x, y):
    """Classify the pixel at (x, y) as "grass", "curb" or "road".

    Trees count as grass. White counts as curb only next to red. Every other
    colour (including water and ice) is "road".
    """
    pixel = color_at(x, y)
    if pixel in (COL_GRASS, COL_TREE):
        return "grass"
    if pixel == COL_CURB_WHITE:
        return "curb" if is_white_curb(x, y) else "road"
    return "curb" if pixel == COL_CURB_RED else "road"


# ===========================================================================
# ExhaustParticle — small white smoke puff
# ===========================================================================
class ExhaustParticle:
    """A short-lived white smoke puff that drifts away from a car's exhaust.

    It grows for the first 30% of its life, then shrinks away. ``update()``
    reports when it has expired.
    """

    def __init__(self, x, y, vx, vy):
        self.x, self.y = x, y
        self.vx, self.vy = vx, vy
        self.life = 0                      # frames lived so far
        self.max_life = randint(8, 16)     # lifetime in frames
        self.size = 0                      # current radius in pixels

    def update(self):
        """Step the puff by one frame; return False once it should be removed."""
        self.life += 1
        # Move first, then apply upward drift (-y) and drag to the velocity.
        self.x, self.y = self.x + self.vx, self.y + self.vy
        self.vx, self.vy = self.vx * 0.95, (self.vy - 0.01) * 0.95
        progress = self.life / self.max_life
        if progress >= 0.3:
            self.size = max(0, 3 - int((progress - 0.3) / 0.7 * 3))   # shrinking
        else:
            self.size = 1 + int(progress / 0.3 * 2)                   # growing
        return self.max_life > self.life

    def draw(self):
        if self.size <= 0:
            return
        circ(int(self.x), int(self.y), self.size, 7)


# ===========================================================================
# Spectator — cheering crowd member that can be run over
# ===========================================================================
class Spectator:
    """One crowd member: idles, cheers when a car is near, and can be run over.

    Appearance is picked with the module-level ``random`` functions.
    generate_spectators() reseeds the RNG before building the crowd, so the
    order of the random calls below decides how each spectator looks.
    """

    def __init__(self, x, y):
        self.x = float(x)
        self.y = float(y)
        self.body_col = choice([8, 12, 2, 9, 14, 4])
        self.skin_col = choice([15, 9])
        self.hat = random() > 0.55              # roughly 45% wear a hat
        self.hat_col = choice([10, 8, 12, 7, 1])
        self.anim = randint(0, 59)              # random start so the crowd bobs out of sync
        self.wave_phase = random() * pi * 2
        self.cheer_timer = 0                    # frames of cheering left
        self.dead = False
        self.blood = []                         # (x, y, radius) splatter circles

    def update(self, car_positions):
        """Advance one frame. Returns True only on the frame this spectator is run over.

        car_positions is iterated twice, so pass a re-iterable sequence of
        (x, y) car centres (the caller passes a list).
        """
        self.anim += 1
        self.wave_phase += 0.08
        if self.dead:
            return False
        for cx, cy in car_positions:
            if ddist(self.x, self.y, cx, cy) < SPECTATOR_HIT_RADIUS:
                self._die(cx, cy)
                return True
        # Any car within 60px (re)starts a 20-frame cheer; otherwise it winds down.
        if any(ddist(self.x, self.y, cx, cy) < 60 for cx, cy in car_positions):
            self.cheer_timer = 20
        elif self.cheer_timer > 0:
            self.cheer_timer -= 1
        return False

    def _die(self, cx, cy):
        """Mark as run over by the car at (cx, cy) and scatter blood around the body."""
        self.dead = True
        dx, dy = self.x - cx, self.y - cy
        d = sqrt(dx * dx + dy * dy) + 0.01      # +0.01 avoids dividing by zero
        self.x += dx / d * 5          # knocked away from the car
        self.y += dy / d * 5
        for _ in range(randint(8, 14)):
            self.blood.append((self.x + randint(-12, 12),
                               self.y + randint(-8, 10), randint(0, 2)))

    def draw(self):
        ix, iy = int(self.x), int(self.y)
        if self.dead:
            self._draw_dead(ix, iy)
        else:
            self._draw_alive(ix, iy)

    def _draw_alive(self, ix, iy):
        """Draw the standing figure; later calls overdraw earlier ones."""
        # Bob up and down while cheering.
        y = iy + (int(sin(self.anim * 0.45) * 3.0) if self.cheer_timer > 0 else 0)
        rect(ix - 1, y - 1, 3, 4, self.body_col)   # body
        rect(ix - 1, y - 4, 3, 3, self.skin_col)   # head
        pset(ix - 1, y - 3, 0)                      # eyes
        pset(ix + 1, y - 3, 0)
        pset(ix - 1, y + 3, 1)                      # feet
        pset(ix + 1, y + 3, 1)
        if self.hat:
            rect(ix - 2, y - 5, 5, 1, self.hat_col)
            rect(ix - 1, y - 6, 3, 1, self.hat_col)
        if self.cheer_timer > 0:                    # waving arms
            wave = int(sin(self.wave_phase) * 1.5)
            pset(ix - 2, y - 3 + wave, self.skin_col)
            pset(ix - 2, y - 4 + wave, self.skin_col)
            pset(ix + 2, y - 3 - wave, self.skin_col)
            pset(ix + 2, y - 4 - wave, self.skin_col)
        else:                                        # arms at side
            pset(ix - 2, y, self.skin_col)
            pset(ix + 2, y, self.skin_col)

    def _draw_dead(self, ix, iy):
        """Draw the blood splatter first, then the flattened body on top."""
        for bx, by, br in self.blood:
            circ(int(bx), int(by), br, 8)
            if br > 0:
                pset(int(bx), int(by), 2)
        rect(ix - 3, iy, 7, 2, self.body_col)        # flattened body
        rect(ix - 4, iy, 2, 2, self.skin_col)        # head
        pset(ix - 4, iy, 0)
        pset(ix - 3, iy + 1, 0)
        for ox in range(-2, 4):
            pset(ix + ox, iy + 2, 8)
        pset(ix, iy + 3, 8)
        pset(ix + 1, iy + 3, 8)


def can_place_spectator(x, y, crowd):
    """Decide whether a new spectator may stand at pixel (x, y).

    The spot must meet all of these:
    - lie at least 6px inside the playfield edges;
    - be plain grass (not a tree, deep water or ice);
    - have some non-grass terrain 8-20px away in one of eight directions;
    - have grass at its centre and 3px out in each direction;
    - be at least 8px from every spectator already in *crowd*.
    """
    inside = 6 <= x <= SCREEN_W - 6 and HUD_H + 6 <= y <= SCREEN_H - 6
    if not inside:
        return False

    unsuitable = (terrain_at(x, y) != "grass"
                  or is_tree_pixel(x, y)
                  or is_deep_water(x, y)
                  or color_at(x, y) == COL_LAKE_ICE)
    if unsuitable:
        return False

    # Must be near the track: some non-grass within ~20px.
    compass = ((1, 0), (-1, 0), (0, 1), (0, -1),
               (1, 1), (-1, 1), (1, -1), (-1, -1))
    sees_track = any(terrain_at(x + dx * r, y + dy * r) != "grass"
                     for r in (8, 12, 16, 20)
                     for dx, dy in compass)
    if not sees_track:
        return False

    # Immediate footprint must be clear grass.
    for dx, dy in ((0, 0), (3, 0), (-3, 0), (0, 3), (0, -3)):
        if terrain_at(x + dx, y + dy) != "grass":
            return False

    for other in crowd:
        if ddist(x, y, other.x, other.y) < 8:
            return False
    return True


def generate_spectators():
    """Build the crowd: up to 300 spectators on grass beside the track.

    The global RNG is reseeded with 99 first, so every call yields the same
    crowd. Placement runs in two passes:
    - the first fills 60% of the target from random points inside
      SPECTATOR_HOTSPOTS;
    - the second tops the crowd up to the full target from uniformly random
      points.
    Each pass stops after 40000 attempts even if its quota is not met.
    """
    seed(99)
    target = 300
    crowd = []

    def hotspot_point():
        hx, hy, hr = choice(SPECTATOR_HOTSPOTS)
        heading = random() * pi * 2
        reach = random() * hr
        return int(hx + cos(heading) * reach), int(hy + sin(heading) * reach)

    def anywhere_point():
        px = randint(6, SCREEN_W - 6)
        py = randint(HUD_H + 6, SCREEN_H - 6)
        return px, py

    passes = ((int(target * 0.6), hotspot_point), (target, anywhere_point))
    for quota, pick in passes:
        for _ in range(40000):
            if len(crowd) >= quota:
                break
            x, y = pick()
            if can_place_spectator(x, y, crowd):
                crowd.append(Spectator(x, y))
    return crowd


# ===========================================================================
# Car — one player's vehicle: physics, checkpoints, exhaust and drawing
# ===========================================================================
class Car:
    """One player's car: driving physics, hazards, checkpoints, effects and drawing.

    Coordinates are screen pixels. ``angle`` is in radians: 0 faces up the
    screen and the angle grows clockwise (see forward_vec()). Terrain, hazards
    and checkpoints are all sampled from the framebuffer, so the track must be
    rendered before the physics methods run.
    """

    # Body panels in local space: (corner points, "body" | "accent" | colour).
    # Local +y points toward the rear of the car; the nose is at y = -9.
    PANELS = [
        ([(-4, -7), (4, -7), (4, 7), (-4, 7)], "body"),     # chassis
        ([(-3, -9), (3, -9), (4, -7), (-4, -7)], "body"),   # nose
        ([(-4, 7), (4, 7), (3, 9), (-3, 9)], "accent"),     # rear wing
        ([(-2, -5), (2, -5), (2, -3), (-2, -3)], 1),        # windshield
    ]
    WHEELS = [(-5, -5), (5, -5), (-5, 5), (5, 5)]

    def __init__(self, number, body_col, accent_col, keys, channel, start):
        """Create a car.

        Args:
            number: player number (1 or 2).
            body_col, accent_col: palette colours for the body and rear wing.
            keys: (left, right, up, down) key codes; up accelerates, down
                brakes and then reverses.
            channel: audio channel used for this car's engine loop.
            start: (x, y) position restored by reset().
        """
        self.number = number                 # 1 or 2
        self.body_col = body_col
        self.accent_col = accent_col
        self.left, self.right, self.up, self.down = keys
        self.channel = channel               # engine-sound audio channel
        self.start_x, self.start_y = start
        self.exhaust = []
        self.reset()

    def reset(self):
        """Put the car back on the grid with no progress, effects or damage."""
        self.x = float(self.start_x)
        self.y = float(self.start_y)
        self.angle = START_ANGLE
        self.prev_angle = START_ANGLE
        self.speed = 0.0
        self.lap = 0
        self.cp_passed = [False] * NUM_CHECKPOINTS
        self.was_on_cp = False
        self.on_finish_last = True           # cars start on the finish line
        self.steer_frames = 0
        self.smoke_tick = 0
        self.engine_snd = SND_ENGINE_IDLE
        self.dead = False
        self.dead_type = ""                  # "fire" or "lake"
        self.dead_timer = 0
        self.exhaust.clear()

    # --- physics -----------------------------------------------------------
    def drive(self):
        """Apply one frame of steering, throttle, terrain limits and movement."""
        self.prev_angle = self.angle
        terrain = self._terrain_under()

        # Steering tightens at low speed and widens at high speed.
        turn = TURN_SPEED * (1.0 - min(abs(self.speed) / MAX_SPEED_TRACK, 1.0) * 0.7)
        if btn(self.left):
            self.angle -= turn
        if btn(self.right):
            self.angle += turn

        accel, brake = ACCEL, BRAKE
        if terrain == "grass":
            accel, brake = accel * 0.55, brake * 0.85
        elif terrain == "curb":
            accel *= 0.85
        if btn(self.up):
            self.speed += accel
        if btn(self.down):
            self.speed -= brake

        max_spd, fric = {
            "road": (MAX_SPEED_TRACK, FRICTION_ROAD),
            "curb": (MAX_SPEED_CURB, FRICTION_CURB),
            "grass": (MAX_SPEED_GRASS, FRICTION_GRASS),
        }[terrain]
        # Reverse is capped at 2 px/frame on every terrain.
        self.speed = max(-2.0, min(self.speed, max_spd)) * fric

        fx, fy = forward_vec(self.angle)
        self.x += fx * self.speed
        self.y += fy * self.speed
        self.clamp_position()

    def clamp_position(self):
        """Keep the car's centre inside the playfield, below the HUD."""
        self.x = clamp(self.x, 8, SCREEN_W - 9)
        self.y = clamp(self.y, HUD_H + 8, SCREEN_H - 9)

    def _footprint(self, *offsets):
        """Map (forward, side) offsets to world points around the car."""
        fx, fy = forward_vec(self.angle)
        sx, sy = side_vec(self.angle)
        return [(self.x + fx * f + sx * s, self.y + fy * f + sy * s)
                for f, s in offsets]

    def _terrain_under(self):
        """Terrain under the car from 5 sample points.

        Two or more grass samples mean "grass"; failing that, two or more curb
        samples mean "curb"; otherwise "road".
        """
        pts = self._footprint((0, 0), (7, 0), (-7, 0), (0, 4), (0, -4))
        counts = {"road": 0, "curb": 0, "grass": 0}
        for px, py in pts:
            counts[terrain_at(px, py)] += 1
        if counts["grass"] >= 2:
            return "grass"
        if counts["curb"] >= 2:
            return "curb"
        return "road"

    # --- hazards -----------------------------------------------------------
    def hazard(self):
        """Return 'tree', 'lake' or None for the car's current position.

        A tree needs 2 of 16 perimeter samples on tree pixels. The lake needs
        all 9 body samples on deep water.
        """
        trees = self._footprint(
            (10, 0), (10, 3), (10, -3), (7, 5), (7, -5),
            (0, 6), (0, -6), (3, 6), (3, -6), (-3, 6), (-3, -6),
            (-10, 0), (-10, 3), (-10, -3), (-7, 5), (-7, -5))
        if sum(1 for px, py in trees if is_tree_pixel(px, py)) >= 2:
            return "tree"
        lake = self._footprint(
            (0, 0), (7, 0), (-7, 0), (0, 4), (0, -4),
            (7, 4), (7, -4), (-7, 4), (-7, -4))
        if all(is_deep_water(px, py) for px, py in lake):
            return "lake"
        return None

    def kill(self, hazard_type):
        """Wreck the car: a tree causes a "fire" death, anything else a "lake" death."""
        self.dead = True
        self.dead_type = "fire" if hazard_type == "tree" else "lake"
        self.dead_timer = 0
        self.speed = 0.0
        stop(self.channel)
        play(2, SND_CRASH if self.dead_type == "fire" else SND_SPLASH)

    # --- checkpoints -------------------------------------------------------
    def _on_checkpoint(self):
        """True if any of 11 points around the car sits on a checkpoint marking."""
        for px, py in self._footprint(
                (0, 0), (5, 0), (-5, 0), (9, 0), (-9, 0), (0, 4), (0, -4),
                (5, 4), (5, -4), (-5, 4), (-5, -4)):
            x = clamp(int(px), 0, SCREEN_W - 1)
            y = clamp(int(py), HUD_H, SCREEN_H - 1)
            c = color_at(x, y)
            if c == COL_CHECKPOINT:
                return True
            # Edge-coloured pixels only count when they border a real marking.
            if c == COL_CHECKPOINT_EDGE and self._checkpoint_neighbor(x, y):
                return True
        return False

    @staticmethod
    def _checkpoint_neighbor(x, y):
        """True if a COL_CHECKPOINT pixel lies 1-2px away (orthogonally) from (x, y)."""
        return any(
            color_at(x + dx, y + dy) == COL_CHECKPOINT
            for dx, dy in ((1, 0), (-1, 0), (0, 1), (0, -1),
                           (2, 0), (-2, 0), (0, 2), (0, -2)))

    def _checkpoint_zone(self):
        """0 = finish line, 1/2/3 = the three checkpoints, by track region.

        The region boundaries are hard-coded for this track's layout.
        """
        if self.y < 130:
            return 0
        if self.x < 300:
            return 1
        if self.x > 800:
            return 2
        return 3

    def update_checkpoints(self):
        """Track lap progress. Returns 'cp', 'lap' or None when crossing a marker.

        A marker counts only on the frame the car first touches it. Checkpoints
        may be taken in any order. A lap counts when the car enters the finish
        zone with every checkpoint passed and did not start the crossing
        already on the finish line.
        """
        on_cp = self._on_checkpoint()
        event = None
        if on_cp and not self.was_on_cp:
            zone = self._checkpoint_zone()
            if zone == 0:
                if not self.on_finish_last and all(self.cp_passed):
                    self.lap += 1
                    self.cp_passed = [False] * NUM_CHECKPOINTS
                    event = "lap"
                self.on_finish_last = True
            else:
                if not self.cp_passed[zone - 1]:
                    self.cp_passed[zone - 1] = True
                    event = "cp"
                self.on_finish_last = False
        elif on_cp:
            self.on_finish_last = self._checkpoint_zone() == 0
        else:
            self.on_finish_last = False
        self.was_on_cp = on_cp
        return event

    # --- effects -----------------------------------------------------------
    def update_engine(self):
        """Loop the engine sound matching the current speed band."""
        spd = abs(self.speed)
        target = (SND_ENGINE_IDLE if spd < 1 else
                  SND_ENGINE_LOW if spd < 3 else SND_ENGINE_MID)
        # Restart when the band changes or the channel has stopped playing.
        if target != self.engine_snd or play_pos(self.channel) is None:
            play(self.channel, target, loop=True)
        self.engine_snd = target

    def spawn_exhaust(self):
        """Maybe emit a smoke puff 13px behind the car, depending on speed."""
        self.smoke_tick ^= 1
        if self.dead or abs(self.speed) < 0.3:
            return
        fx, fy = forward_vec(self.angle)
        emit = 0.3 + abs(self.speed) * 0.12
        vx = -fx * emit + (random() - 0.5) * 0.3
        vy = -fy * emit + (random() - 0.5) * 0.3
        # Fast cars puff every frame; medium cars every other frame.
        if abs(self.speed) > 3.0 or (self.smoke_tick == 0 and abs(self.speed) > 1.5):
            self.exhaust.append(ExhaustParticle(self.x - fx * 13, self.y - fy * 13, vx, vy))
        while len(self.exhaust) > MAX_EXHAUST_PARTICLES:
            self.exhaust.pop(0)

    def update_exhaust(self):
        """Advance every puff and drop the expired ones."""
        self.exhaust = [p for p in self.exhaust if p.update()]

    def draw_exhaust(self):
        for p in self.exhaust:
            p.draw()

    def skidding(self):
        """Advance the skid timer; True once the car has been sliding long enough."""
        turn_rate = abs(angle_diff(self.angle, self.prev_angle))
        if turn_rate > SKID_TURN_THRESHOLD and abs(self.speed) > SKID_MIN_SPEED:
            self.steer_frames += 1
        else:
            self.steer_frames = 0
        return self.steer_frames >= SKID_SUSTAIN_FRAMES

    # --- drawing -----------------------------------------------------------
    def draw(self):
        """Draw the car: intact, burning (tree crash) or sinking (lake)."""
        if self.dead_type == "lake":
            self._draw_sinking()
            return
        self._draw_chassis(1.0, details=True)
        if self.dead_type == "fire":
            self._draw_fire()

    def _draw_chassis(self, scale, details=False):
        """Draw wheels and body panels, scaled by *scale*.

        With details=True, also draw the lights, exhaust pipe and stripes.
        """
        ca, sa = cos(self.angle), sin(self.angle)

        def rot(px, py):
            px, py = px * scale, py * scale
            return self.x + px * ca - py * sa, self.y + px * sa + py * ca

        for wx, wy in self.WHEELS:
            self._quad(rot, [(wx - 2, wy - 2), (wx + 2, wy - 2),
                             (wx + 2, wy + 2), (wx - 2, wy + 2)], 1)
        for pts, key in self.PANELS:
            col = (self.body_col if key == "body" else
                   self.accent_col if key == "accent" else key)
            self._quad(rot, pts, col)
        if not details:
            return

        for lx in (-2, 2):                                  # headlights
            x, y = rot(lx, -9)
            pset(int(x), int(y), 10)
        for lx in (-3, 3):                                  # taillights
            x, y = rot(lx, 9)
            pset(int(x), int(y), 8)
        self._quad(rot, [(-2, 10), (2, 10), (2, 12), (-2, 12)], 0)  # exhaust pipe
        x0, y0 = rot(-2, 10)
        x1, y1 = rot(2, 10)
        line(int(x0), int(y0), int(x1), int(y1), 5)
        for sx in (-2, 2):                                  # racing stripes
            a, b = rot(sx, -8), rot(sx, 8)
            line(int(a[0]), int(a[1]), int(b[0]), int(b[1]), 7)

    @staticmethod
    def _quad(rot, pts, col):
        """Fill the quadrilateral *pts* (after applying *rot*) as two triangles."""
        (ax, ay), (bx, by), (cx, cy), (dx, dy) = [rot(px, py) for px, py in pts]
        tri(ax, ay, bx, by, cx, cy, col)
        tri(ax, ay, cx, cy, dx, dy, col)

    def _draw_sinking(self):
        """Shrink the car over SINK_DURATION frames, then show bubbles until SINK_BUBBLE_END."""
        t = self.dead_timer
        if t >= SINK_DURATION:                  # car gone, only bubbles linger
            if t < SINK_BUBBLE_END:
                self._draw_bubbles(rise=(t - SINK_DURATION) * 0.25, radius=2)
            return
        scale = 1.0 - t / SINK_DURATION
        if scale > 0.05:
            self._draw_chassis(scale)
        if t > 20:
            self._draw_bubbles(rise=(t - 20) * 0.12, radius=1)

    def _draw_bubbles(self, rise, radius):
        """Draw up to 4 bubbles above the car, one more every 25 frames."""
        phase = self.dead_timer * 0.2
        for i in range(min(4, 1 + self.dead_timer // 25)):
            bx = self.x + int(sin(phase + i * 1.9) * 5)
            by = self.y - 3 - int(rise) - i * 3
            circ(int(bx), int(by), radius, 6)

    def _draw_fire(self):
        """Animated smoke, flames and embers over a car that hit a tree."""
        t = self.dead_timer
        phase = t * 0.2
        for i in range(4):                       # rising smoke
            x = self.x + int(sin(phase + i * 1.5) * 5)
            y = self.y - 10 - i * 4 - t // 3
            r = 3 + i
            if t > 30:
                circ(int(x), int(y), r, 5)
            circ(int(x), int(y), max(1, r - 1), 0)
        for i in range(5):                       # flames
            x = self.x + int(sin(phase * 1.3 + i * 1.1) * 4)
            y = self.y - 2 - int(sin(phase + i) * 3) - i * 2
            circ(int(x), int(y), max(1, 4 - i), 10 if (t + i) % 3 == 0 else 9)
        for i in range(3):                       # embers
            x = self.x + int(sin(phase * 0.7 + i * 2.3) * 6)
            y = self.y + int(cos(phase * 0.5 + i) * 3)
            pset(int(x), int(y), 8)


# ===========================================================================
# Game — state machine, world generation, HUD and the Pyxel loop
# ===========================================================================
class Game:
    """Top-level app: owns both cars and the world, runs the state machine
    (title -> countdown -> racing -> finished) and draws the HUD."""

    def __init__(self):
        """Set up Pyxel and resources, create both cars, then hand over to pyxel.run()."""
        init(SCREEN_W, SCREEN_H, title="Micro Racers", fps=30)
        load("res5.pyxres")
        define_sounds()
        self.cars = [
            Car(1, 12, 1, (KEY_A, KEY_D, KEY_W, KEY_S), 0, (525, 80)),
            Car(2, 8, 2, (KEY_LEFT, KEY_RIGHT, KEY_UP, KEY_DOWN), 1, (525, 110)),
        ]
        self.reset()
        run(self.update, self.draw)

    def reset(self):
        """Return to the title screen with fresh cars and an empty world."""
        for car in self.cars:
            car.reset()
        self.state = "title"          # title -> countdown -> racing -> finished
        self.countdown_start = 0
        self.winner = 0               # 0 = undecided, else the winning car number
        self.collision_cd = 0         # frames until a car-vs-car crash can sound again
        self.spectator_hit_cd = 0     # frames until a run-over crash can sound again
        # Rear-wheel skid points; maxlen drops the oldest in O(1) once full.
        self.skidmarks = deque(maxlen=MAX_SKIDMARKS)
        self.spectators = []
        self.lake_waves = []
        self.world_ready = False      # spectators + waves need the track rendered

    # --- world generation (needs the rendered track to read pixels) --------
    def generate_world(self):
        """Place spectators and pick deep-water pixels (on a 4px grid) for wave shimmer."""
        self.spectators = generate_spectators()
        self.lake_waves = []
        for y in range(HUD_H + 2, SCREEN_H - 2, 4):
            for x in range(2, SCREEN_W - 2, 4):
                if color_at(x, y) != COL_LAKE_DEEP:
                    continue
                # Colour 5 is also road; require water/ice neighbours 4px away.
                near = sum(
                    1 for dx, dy in ((4, 0), (-4, 0), (0, 4), (0, -4))
                    if color_at(x + dx, y + dy) in (COL_LAKE_DEEP, COL_LAKE_ICE))
                if near >= 2:
                    self.lake_waves.append((x, y, random() * pi * 2))
        self.world_ready = True

    # --- update ------------------------------------------------------------
    def update(self):
        """Per-frame logic: tick cooldowns, then run the current state."""
        if self.collision_cd > 0:
            self.collision_cd -= 1
        if self.spectator_hit_cd > 0:
            self.spectator_hit_cd -= 1

        if self.state == "title":
            stop(0)
            stop(1)
            if btnp(KEY_SPACE):
                self.state = "countdown"
                self.countdown_start = pyxel.frame_count
            return

        render_track()                       # so terrain/checkpoint reads are fresh
        if not self.world_ready:
            self.generate_world()

        if self.state == "countdown":
            self.update_countdown()
        elif self.state == "racing":
            self.update_racing()
        elif self.state == "finished":
            self.update_finished()

    def update_countdown(self):
        """Beep at 1, 30 and 60 frames, then start the race at frame 90."""
        stop(0)
        stop(1)
        elapsed = pyxel.frame_count - self.countdown_start
        if elapsed in (1, 30, 60):
            play(3, SND_COUNTDOWN_BEEP)
        if elapsed >= 90:
            play(3, SND_COUNTDOWN_GO)
            self.state = "racing"

    def update_finished(self):
        """Let effects play out after the race; R returns to the title screen."""
        stop(0)
        stop(1)
        for car in self.cars:
            if car.dead:
                car.dead_timer += 1          # let death animations finish
            car.update_exhaust()
        if btnp(KEY_R):
            self.reset()

    def update_racing(self):
        """One frame of racing.

        Order: checkpoints, death timers, driving and hazards, collisions,
        effects, spectators, then the win check.
        """
        # Checkpoints are read at the pre-move position, from the track
        # rendered at the top of update().
        events = [car.update_checkpoints() for car in self.cars if not car.dead]
        if "cp" in events:
            play(3, SND_CHECKPOINT)
        if "lap" in events:
            play(3, SND_LAP)

        # A dead car's animation plays out, then the race ends.
        for car in self.cars:
            if car.dead:
                car.dead_timer += 1
                if car.dead_timer == 90:
                    play(3, SND_VICTORY)
                    self.state = "finished"

        for car in self.cars:
            if not car.dead:
                car.drive()
                hazard = car.hazard()
                if hazard:
                    self.kill(car, hazard)
        self.resolve_collision()

        for car in self.cars:
            if not car.dead:
                car.update_engine()
                if car.skidding():
                    self.add_skidmarks(car)
            car.spawn_exhaust()
            car.update_exhaust()

        positions = [(c.x, c.y) for c in self.cars if not c.dead]
        for spectator in self.spectators:
            if spectator.update(positions) and self.spectator_hit_cd <= 0:
                play(2, SND_CRASH)
                self.spectator_hit_cd = 10

        for car in self.cars:
            if not car.dead and car.lap >= LAPS_TO_WIN and self.winner == 0:
                self.win(car.number)

    def kill(self, car, hazard):
        """Wreck *car*; the first crash of the race hands the win to the other player."""
        car.kill(hazard)
        if self.winner == 0:                 # the surviving player wins
            self.winner = 2 if car.number == 1 else 1

    def win(self, number):
        """End the race immediately with player *number* as the winner."""
        self.winner = number
        stop(0)
        stop(1)
        play(3, SND_VICTORY)
        self.state = "finished"

    def resolve_collision(self):
        """Push overlapping cars apart, slow both down and play a rate-limited crash sound."""
        a, b = self.cars
        if a.dead or b.dead:
            return
        d = ddist(a.x, a.y, b.x, b.y)
        if 0 < d < CAR_RADIUS:               # d == 0 has no direction to push along
            ox, oy = (a.x - b.x) / d, (a.y - b.y) / d
            push = (CAR_RADIUS - d) / 2 + 1 + CAR_REPEL_FORCE
            a.x += ox * push
            a.y += oy * push
            b.x -= ox * push
            b.y -= oy * push
            a.speed *= 0.65
            b.speed *= 0.65
            a.clamp_position()
            b.clamp_position()
            if self.collision_cd <= 0:
                play(2, SND_CRASH)
                self.collision_cd = 15

    def add_skidmarks(self, car):
        """Record one skid point under each of the car's two rear wheels."""
        sa, ca = sin(car.angle), cos(car.angle)
        for wx, wy in ((-5, 5), (5, 5)):     # the two rear wheels
            x = int(car.x + wx * ca - wy * sa)
            y = int(car.y + wx * sa + wy * ca)
            self.skidmarks.append((x, y))    # deque(maxlen=...) evicts the oldest

    # --- draw --------------------------------------------------------------
    def draw(self):
        """Per-frame rendering: track, ground effects, crowd, cars, then HUD/overlays."""
        if self.state == "title":
            draw_title_art(pyxel.frame_count)
            return

        render_track()
        self.draw_lake_waves()
        for x, y in self.skidmarks:
            pset(x, y, 0)
            pset(x + 1, y, 1)
            pset(x + 1, y - 1, 1)
            pset(x - 1, y + 1, 1)
        for car in self.cars:
            car.draw_exhaust()
        for spectator in self.spectators:
            spectator.draw()
        for car in self.cars:
            car.draw()

        self.draw_hud()
        self.draw_overlays()

    def draw_lake_waves(self):
        """Animate a shimmer of light crests and dark troughs over the lake."""
        t = pyxel.frame_count * 0.1
        for wx, wy, phase in self.lake_waves:
            wave = sin(t + phase + wx * 0.04 + wy * 0.03)
            if wave > 0.2:                       # shimmer drifting across the surface
                px = wx + int(sin(t * 0.6 + phase) * 2)
                py = wy + int(cos(t * 0.4 + phase) * 1)
                if 0 <= px < SCREEN_W and HUD_H <= py < SCREEN_H:
                    pset(px, py, 6)
                    if wave > 0.5 and px + 1 < SCREEN_W:
                        pset(px + 1, py, 12)     # bright wave crest
            elif wave < -0.4:
                pset(wx, wy, 1)                  # dark trough

    def draw_hud(self):
        """Draw the per-player labels, speed bars, lap and checkpoint counters."""
        p1, p2 = self.cars
        # Each label is drawn exactly once. P2's label is right-aligned, so
        # drawing "P2 OUT" on top of "P2" (at a different x) garbled both.
        if p1.dead:
            text(4, 3, "P1 OUT", 8)
        else:
            text(4, 3, "P1", 12)
        self.draw_speed_bar(46, 4, p1.speed, 12)
        text(4, 11, "Lap %d/%d" % (min(p1.lap + 1, LAPS_TO_WIN), LAPS_TO_WIN), 7)
        text(60, 11, "CP:%d/%d" % (sum(p1.cp_passed), NUM_CHECKPOINTS), 10)

        if p2.dead:
            text(SCREEN_W - 30, 3, "P2 OUT", 8)
        else:
            text(SCREEN_W - 18, 3, "P2", 8)
        self.draw_speed_bar(SCREEN_W - 96, 4, p2.speed, 8)
        text(SCREEN_W - 56, 11, "Lap %d/%d" % (min(p2.lap + 1, LAPS_TO_WIN), LAPS_TO_WIN), 7)
        text(SCREEN_W - 130, 11, "CP:%d/%d" % (sum(p2.cp_passed), NUM_CHECKPOINTS), 10)

        text(SCREEN_W // 2 - 20, 3, "%d LAPS" % LAPS_TO_WIN, 7)

    @staticmethod
    def draw_speed_bar(x, y, speed, col):
        """50px bar filled in proportion to |speed| / MAX_SPEED_TRACK."""
        rect(x, y, int(abs(speed) / MAX_SPEED_TRACK * 50), 4, col)
        rectb(x, y, 50, 4, 5)

    def draw_overlays(self):
        """Countdown digits, the brief "GO!" message and the finish banner."""
        cx, cy = SCREEN_W // 2, SCREEN_H // 2
        elapsed = pyxel.frame_count - self.countdown_start
        if self.state == "countdown":
            text(cx - 2, cy, "3" if elapsed < 30 else "2" if elapsed < 60 else "1",
                 10 if elapsed >= 60 else 7)
        elif self.state == "racing" and 90 <= elapsed < 110:
            text(cx - 6, cy, "GO!", 11)
        elif self.state == "finished":
            self.draw_finish_banner()

    def draw_finish_banner(self):
        """Centre-screen box naming the winner (and whether the loser crashed)."""
        p1, p2 = self.cars
        bw, bh = 220, 70
        bx, by = (SCREEN_W - bw) // 2, (SCREEN_H - bh) // 2
        rect(bx, by, bw, bh, 0)
        rectb(bx, by, bw, bh, 7)
        if self.winner == 1:
            msg = "P2 CRASHED! P1 WINS!" if p2.dead else "PLAYER 1 (BLUE) WINS!"
            col = 12
        else:
            msg = "P1 CRASHED! P2 WINS!" if p1.dead else "PLAYER 2 (RED) WINS!"
            col = 8
        hint = "Press R to restart"
        # Centre both lines horizontally using Pyxel's fixed-width font.
        text(bx + (bw - len(msg) * pyxel.FONT_WIDTH) // 2, by + 18, msg, col)
        text(bx + (bw - len(hint) * pyxel.FONT_WIDTH) // 2, by + 45, hint, 13)
        text(bx + 55, by + 45, "Press R to restart", 13)


# Only start the game when run as a script (e.g. `python file.py` or
# `pyxel run`), not when the module is merely imported.
if __name__ == "__main__":
    Game()