import pyxel
import random

# ══════════════════════════════════════════════════════════
#  CONSTANTS
# ══════════════════════════════════════════════════════════
# Screen size in pixels (the draw code tiles the background up to these
# bounds), the edge length of one board cell in pixels, and the board
# dimension in cells (boards are GRID_SIZE x GRID_SIZE).
SCREEN_W  = 345
SCREEN_H  = 236
CELL      = 16
GRID_SIZE = 10

# Pixel origin of the two boards: LEFT_X / RIGHT_X are the x positions of
# the left and right board, GRID_Y is the shared y position.
LEFT_X  = 4
RIGHT_X = LEFT_X + GRID_SIZE * CELL + 16
GRID_Y  = 32

# Colours (palette indices passed to the pyxel drawing calls).
# Note: BLUE and CYAN both map to index 12.
BLACK      = 0
DARK_BLUE  = 1
PURPLE     = 2
DARK_GRN   = 3
BROWN      = 4
DARK_GRAY  = 5
LIGHT_GRAY = 6
WHITE      = 7
RED        = 8
ORANGE     = 9
YELLOW     = 10
GREEN      = 11
BLUE       = 12
CYAN       = 12
INDIGO     = 13
PINK       = 14
PEACH      = 15

# Semantic colour roles expressed in terms of the palette names above.
COL_BG       = DARK_BLUE
COL_GRID     = INDIGO
COL_TEXT     = WHITE
COL_DIM      = DARK_GRAY
COL_CURSOR   = YELLOW
COL_PREVIEW  = GREEN
COL_CONFLICT = RED
COL_HIT      = RED
COL_SUNK     = ORANGE

# Cell states stored in a board grid.
EMPTY   = 0
SHIP_OK = 1
MISS    = 2
HIT     = 3
SUNK    = 4

# Weapon modes. The values double as the index of the matching label in
# the weapon bar drawn by draw_battle_ui.
MODE_NORMAL  = 0
MODE_BOMB    = 1
MODE_TORPEDO = 2

# Game phases (compared against Game.phase).
PHASE_PLACE   = "place"
PHASE_HANDOFF = "handoff"
PHASE_BATTLE  = "battle"
PHASE_RESULT  = "result"
PHASE_TORPEDO = "torpedo_anim"
PHASE_EXPLODE = "explode"
PHASE_OVER    = "over"
PHASE_TITLE  = "title"

# Lengths of the ships each player places, in placement order.
SHIPS = [5, 4, 3, 3, 2]

# Key repeat (see key_active): number of polls a key must be held before
# repeating starts, and the interval between repeats after that.
KEY_INITIAL_DELAY = 20
KEY_REPEAT_RATE   = 4
# Key that, while held, shows the player's own ships during the torpedo
# and explosion phases.
KEY_REVEAL        = pyxel.KEY_A

# Torpedo advance in pixels per update, and the number of updates the
# explosion phase lasts.
TORP_SPEED  = 3
EXPLODE_DUR = 45

# ══════════════════════════════════════════════════════════
#  SPRITES
# ══════════════════════════════════════════════════════════

def load_sprites():
    """Write all 16x16 tiles used by the game into image bank 1.

    Every tile is given as 16 rows of 16 hex digits, one digit per pixel
    (the digit is the palette index). Layout inside the bank:

    * y = 0:  three water animation frames at x = 0, 16, 32
    * y = 16: five explosion frames (frame 0..4) at x = 0, 16, 32, 48, 64
    * y = 32: torpedo (horizontal, vertical), torpedo trail (horizontal,
      vertical), miss splash and wreck at x = 0, 16, 32, 48, 64, 80

    All rows are exactly 16 characters wide so no tile spills into its
    right-hand neighbour.
    """
    im = pyxel.images[1]

    # -- Water - 3 frames
    im.set(0, 0, [
        "1111166666cccc66",
        "1116666ccccccc6c",
        "116cccccccccc6cc",
        "66cccccccc6ccc66",
        "6cccccc6ccc66611",
        "ccccc6cc66661111",
        "cccc6c6666111166",
        "ccc6c666111166cc",
        "cc6c6661116666cc",
        "c6c6666116666ccc",
        "66666661166ccccc",
        "6666661166cccccc",
        "666661166ccccccc",
        "66661166cccccc6c",
        "6661166ccccccc66",
        "661166cccccc6666",
    ])
    im.set(16, 0, [
        "cc666611116666cc",
        "6666661116666ccc",
        "666661166ccccccc",
        "66661166cccccccc",
        "6661166ccccccccc",
        "111166cccccccccc",
        "11166ccccccccccc",
        "1166cccccccccc6c",
        "166ccccccccccc66",
        "66cccccccccc6666",
        "6ccccccccc666666",
        "cccccccccc66cccc",
        "cccccccc66cccccc",
        "cccccc6666cccccc",
        "ccccc66666cccccc",
        "cccc666666cccccc",
    ])
    im.set(32, 0, [
        "1166cccccccc6666",
        "166ccccccccc6666",
        "66cccccccc666611",
        "6ccccccc66661111",
        "cccccc6666111166",
        "ccccc66661111666",
        "cccc666611116666",
        "ccc6666111166ccc",
        "cc66661116666ccc",
        "c6666111666666cc",
        "66661116666ccccc",
        "6661166ccccccccc",
        "661166cccccccccc",
        "61166ccccccccccc",
        "1166cccccccccccc",
        "166ccccccccccccc",
    ])

    # -- Explosions - 5 frames
    # Frame 0 - hit marker X (red)
    im.set(0, 16, [
        "8800000000000088",
        "8880000000000888",
        "0888000000008880",
        "0088800000088800",
        "0008888000888000",
        "0000888888880000",
        "0000088888800000",
        "0000008888000000",
        "0000008888000000",
        "0000088888800000",
        "0000888888880000",
        "0008888000888000",
        "0088800000088800",
        "0888000000008880",
        "8880000000000888",
        "8800000000000088",
    ])
    # Frame 1 - small, orange
    im.set(16, 16, [
        "0000000000000000",
        "0000000000000000",
        "0000009999000000",
        "0000099999900000",
        "0000999999990000",
        "0009999999999000",
        "0009999999999000",
        "0009999999999000",
        "0009999999999000",
        "0009999999999000",
        "0000999999990000",
        "0000099999900000",
        "0000009999000000",
        "0000000000000000",
        "0000000000000000",
        "0000000000000000",
    ])
    # Frame 2 - medium, yellow-orange
    im.set(32, 16, [
        "0000000900000000",
        "0000009999000000",
        "0000099999900000",
        "0000999999990000",
        "0009999999999000",
        "0099999999999900",
        "0999999999999990",
        "9999999999999999",
        "9999999999999999",
        "0999999999999990",
        "0099999999999900",
        "0009999999999000",
        "0000999999990000",
        "0000099999900000",
        "0000009999000000",
        "0000000900000000",
    ])
    # Frame 3 - large, yellow
    im.set(48, 16, [
        "0a00000000000a00",
        "a0aaaa0000aaaa0a",
        "0aaaaaa00aaaaaa0",
        "0aaaaaaaaaaaaa00",
        "aaaaaaaaaaaaaa00",
        "aaaaaaaaaaaaaa00",
        "aaaaaaaaaaaaaa00",
        "aaaaaaaaaaaaaa00",
        "aaaaaaaaaaaaaa00",
        "aaaaaaaaaaaaaa00",
        "aaaaaaaaaaaaaa00",
        "0aaaaaaaaaaaaa00",
        "0aaaaaa00aaaaaa0",
        "a0aaaa0000aaaa0a",
        "0a00000000000a00",
        "0000000000000000",
    ])
    # Frame 4 - fading, red
    im.set(64, 16, [
        "8000000000000008",
        "0800000000000080",
        "0080000000000800",
        "0008000000008000",
        "0000800000080000",
        "0000080000800000",
        "0000008008000000",
        "0000000000000000",
        "0000000000000000",
        "0000008008000000",
        "0000080000800000",
        "0000800000080000",
        "0008000000008000",
        "0080000000000800",
        "0800000000000080",
        "8000000000000008",
    ])

    # -- Torpedo, horizontal
    im.set(0, 32, [
        "0000000000000000",
        "0000000000000000",
        "000000aaa9000000",
        "0000aaaa99000000",
        "00aaaaaa99000000",
        "9999999aa9000000",
        "9999999aa9000000",
        "00aaaaaa99000000",
        "0000aaaa99000000",
        "000000aaa9000000",
        "0000000000000000",
        "0000000000000000",
        "0000000000000000",
        "0000000000000000",
        "0000000000000000",
        "0000000000000000",
    ])
    # -- Torpedo, vertical
    im.set(16, 32, [
        "0000000000000000",
        "0000000aa0000000",
        "000000a99a000000",
        "00000aa99aa00000",
        "0000099999900000",
        "0000099999900000",
        "0000099999900000",
        "0000099999900000",
        "0000099999900000",
        "0000099999900000",
        "0000009999000000",
        "0000000990000000",
        "0000000900000000",
        "0000000000000000",
        "0000000000000000",
        "0000000000000000",
    ])
    # -- Torpedo trail, horizontal
    im.set(32, 32, [
        "0000000000000000",
        "6000000000000000",
        "0060000000000000",
        "0006000000000000",
        "0000600000000000",
        "0000060000000000",
        "0000006000000000",
        "0000000600000000",
        "0000000060000000",
        "0000000006000000",
        "0000000000600000",
        "0000000000060000",
        "0000000000006000",
        "0000000000000600",
        "0000000000000060",
        "0000000000000006",
    ])
    # -- Torpedo trail, vertical
    im.set(48, 32, [
        "0060000000600000",
        "0000000000000000",
        "0006000000060000",
        "0000000000000000",
        "0000600000006000",
        "0000000000000000",
        "0000060000000600",
        "0000000000000000",
        "0000006000000060",
        "0000000000000000",
        "0000000600000006",
        "0000000000000000",
        "0000000060000000",
        "0000000000000000",
        "0000000006000000",
        "0000000000000000",
    ])
    # -- Miss splash
    im.set(64, 32, [
        "0000000000000000",
        "0000006000000000",
        "0000666000000000",
        "0006666600000000",
        "006666cc66000000",
        "006cc6cc6c600000",
        "0006cc6cc6600000",
        "00006c6c66000000",
        "0000066666000000",
        "0000006660000000",
        "0000000600000000",
        "0000000000000000",
        "0000000000000000",
        "0000000000000000",
        "0000000000000000",
        "0000000000000000",
    ])
    # -- Sunk ship / wreck
    im.set(80, 32, [
        "4445444454444544",
        "4544494444944445",
        "5449444944449444",
        "4498944494498944",
        "498a89498998a894",
        "4498948a98998944",
        "444949a949449444",
        "5444949494444945",
        "4494444944944444",
        "4989449894494984",
        "98a8998a998898a9",
        "498948a9849a8984",
        "44949a9499a94944",
        "4449894989498944",
        "5444944494449445",
        "4544445444544454",
    ])

# ══════════════════════════════════════════════════════════
#  SOUNDEFFEKTE
# ══════════════════════════════════════════════════════════
def setup_sounds():
    """Configure the six sound effect slots (0-5) of ``pyxel.sounds``.

    Each preset is ``(slot, notes, tones, volumes, effects, speed)``; the
    values are applied to the slot in exactly that order.
    """
    presets = (
        (0, "c3 e3", "s", "6", "n", 8),         # ship placed
        (1, "a3 g3", "s", "6", "f", 7),         # water / miss
        (2, "c4 a3 f3", "s", "7", "f", 5),      # hit
        (3, "c4 a3 f3 d3", "s", "7", "f", 7),   # ship sunk
        (4, "c3 e3 g3 c3", "s", "7", "n", 10),  # victory
        (5, "b2 a2", "p", "5", "n", 6),         # invalid placement
    )
    for slot, notes, tones, volumes, effects, speed in presets:
        sound = pyxel.sounds[slot]
        sound.set_notes(notes)
        sound.set_tones(tones)
        sound.set_volumes(volumes)
        sound.set_effects(effects)
        sound.speed = speed

# ══════════════════════════════════════════════════════════
#  KEY-REPEAT
# ══════════════════════════════════════════════════════════
# Number of consecutive polls each key has been seen held down.
_key_held = {}


def key_active(key):
    """Poll ``key`` and report whether it should trigger an action now.

    Returns True on the first poll the key is down, and then again every
    KEY_REPEAT_RATE polls once it has been held for more than
    KEY_INITIAL_DELAY polls. Releasing the key resets its counter.
    """
    if not pyxel.btn(key):
        _key_held[key] = 0
        return False

    frames = _key_held.get(key, 0) + 1
    _key_held[key] = frames
    if frames == 1:
        return True

    past_delay = frames - KEY_INITIAL_DELAY
    return past_delay > 0 and past_delay % KEY_REPEAT_RATE == 0

# ══════════════════════════════════════════════════════════
#  HILFSFUNKTIONEN
# ══════════════════════════════════════════════════════════

def make_grid():
    """Return a fresh GRID_SIZE x GRID_SIZE board with every cell EMPTY."""
    board = []
    for _ in range(GRID_SIZE):
        # A new list per row so rows never share storage.
        board.append([EMPTY] * GRID_SIZE)
    return board

def ship_cells(row, col, length, horiz):
    """Return the ``(row, col)`` cells of a ship starting at ``(row, col)``.

    The ship extends to the right when ``horiz`` is true, otherwise
    downwards. No bounds checking is done here.
    """
    if horiz:
        return [(row, col + step) for step in range(length)]
    return [(row + step, col) for step in range(length)]

def can_place(grid, cells):
    """Return True if a ship may occupy ``cells`` on ``grid``.

    Every cell must lie on the board and be EMPTY, and no cell may touch
    an existing ship orthogonally (diagonal contact is allowed).
    """
    footprint = set(cells)
    neighbours = ((-1, 0), (1, 0), (0, -1), (0, 1))

    def on_board(rr, cc):
        return 0 <= rr < GRID_SIZE and 0 <= cc < GRID_SIZE

    for r, c in cells:
        if not on_board(r, c) or grid[r][c] != EMPTY:
            return False
        for dr, dc in neighbours:
            nr, nc = r + dr, c + dc
            # Cells of the ship itself and off-board cells cannot conflict.
            if (nr, nc) in footprint or not on_board(nr, nc):
                continue
            if grid[nr][nc] == SHIP_OK:
                return False
    return True

def place_ship(grid, cells):
    """Mark every cell in ``cells`` as an intact ship segment on ``grid``."""
    for cell in cells:
        row, column = cell
        grid[row][column] = SHIP_OK

def all_sunk(grid):
    """Return True when ``grid`` contains no intact (SHIP_OK) cell."""
    for r in range(GRID_SIZE):
        for c in range(GRID_SIZE):
            if grid[r][c] == SHIP_OK:
                return False
    return True

def check_and_mark_sunk(grid, ships):
    """Turn every fully hit ship in ``ships`` into SUNK cells.

    A ship counts as sunk when all of its cells are HIT; its cells are
    then set to SUNK and sound 3 is played on channel 3.
    """
    for ship in ships:
        if any(grid[r][c] != HIT for r, c in ship):
            continue
        for r, c in ship:
            grid[r][c] = SUNK
        pyxel.play(3, 3, loop=False)

# ══════════════════════════════════════════════════════════
#  SPIELSTAND
# ══════════════════════════════════════════════════════════

class PlayerState:
    """Per-player data: board, placed ships, special ammo and statistics."""

    def __init__(self):
        # Own board and the list of placed ships (each a list of cells).
        self.grid = make_grid()
        self.ships = []
        # Special weapons still available.
        self.bombs, self.torpedos = 2, 1
        # Shots fired and hits scored so far.
        self.shots = self.hits = 0

class Game:
    """Complete mutable state of one match, from title screen to game over."""

    def __init__(self):
        # -- Players and overall flow
        self.p = [PlayerState(), PlayerState()]
        self.current = 0                # index into self.p of the active player
        self.phase = PHASE_TITLE
        self.next_phase = PHASE_PLACE   # phase entered after the handoff screen
        self.handoff_msg = ""
        self.message = ""
        self.winner = -1                # set to a player index when the game ends

        # -- Ship placement: cursor, orientation, index into SHIPS
        self.cur_r = 0
        self.cur_c = 0
        self.cur_horiz = True
        self.ship_idx = 0

        # -- Battle: target cursor and selected weapon
        self.bc_r = 0
        self.bc_c = 0
        self.mode = MODE_NORMAL
        self.torp_horiz = True

        # -- Outcome of the last shot, shown during the result phase
        self.last_was_hit = False
        self.last_cells = []
        self.stay_on_turn = False
        self.anim_timer = 0

        # -- Torpedo flight and explosion animation
        self.torp_path = []
        self.torp_progress = 0.0
        self.torp_total_px = 0
        self.torp_was_hit = False
        self.torp_affected = []
        self.explode_timer = 0

        # (player index, row, col) -> frame count at which a normal shot missed
        self.miss_time = {}


g = Game()

# ══════════════════════════════════════════════════════════
#  SCHÜSSE
# ══════════════════════════════════════════════════════════

def opp(current):
    """Return the index of the other player (0 <-> 1)."""
    other = 1 - current
    return other

def fire_normal(g, r, c):
    """Fire a single shot at cell ``(r, c)`` of the opponent's board.

    Returns ``(fired, cells)``: ``(False, [])`` if the cell was already
    shot at, otherwise ``(True, [(r, c)])``. Updates the shooter's
    statistics, the board, ``g.message`` and plays the hit/miss sound.
    A miss also records the current frame in ``g.miss_time``.
    """
    target_id = opp(g.current)
    defender = g.p[target_id]
    board = defender.grid

    if board[r][c] in (MISS, HIT, SUNK):
        g.message = "Bereits beschossen!"
        return False, []

    shooter = g.p[g.current]
    shooter.shots += 1
    if board[r][c] == SHIP_OK:
        board[r][c] = HIT
        shooter.hits += 1
        g.message = "TREFFER! Nochmals schiessen!"
        pyxel.play(2, 2, loop=False)
    else:
        board[r][c] = MISS
        g.miss_time[(target_id, r, c)] = pyxel.frame_count
        g.message = "Wasser..."
        pyxel.play(1, 1, loop=False)

    check_and_mark_sunk(board, defender.ships)
    return True, [(r, c)]

def fire_bomb(g, r, c):
    """Drop a bomb on the 3x3 area centred on ``(r, c)``.

    Returns ``(False, [])`` when the player has no bombs left. Otherwise
    one bomb and one shot are consumed, every intact ship cell in the area
    becomes HIT and every empty cell becomes MISS, and the function
    returns ``(True, cells)`` with the cells that changed (row-major).
    """
    attacker = g.p[g.current]
    defender = g.p[opp(g.current)]
    board = defender.grid

    if attacker.bombs <= 0:
        g.message = "Keine Bomben mehr!"
        return False, []

    attacker.bombs -= 1
    attacker.shots += 1

    hit_count = 0
    touched = []
    for nr in range(r - 1, r + 2):
        if not 0 <= nr < GRID_SIZE:
            continue
        for nc in range(c - 1, c + 2):
            if not 0 <= nc < GRID_SIZE:
                continue
            state = board[nr][nc]
            if state == SHIP_OK:
                board[nr][nc] = HIT
                attacker.hits += 1
                hit_count += 1
            elif state == EMPTY:
                board[nr][nc] = MISS
            else:
                # Already shot at: leave untouched and do not report it.
                continue
            touched.append((nr, nc))

    check_and_mark_sunk(board, defender.ships)

    if hit_count:
        g.message = "BOMBE! " + str(hit_count) + " Treffer! Nochmals schiessen!"
        pyxel.play(2, 2, loop=False)
    else:
        g.message = "Bombe: Wasser!"
        pyxel.play(1, 1, loop=False)
    return True, touched

def start_torpedo(g, row, col, horiz):
    """Launch a torpedo along ``row`` (horizontal) or ``col`` (vertical).

    Returns False if the player has no torpedo left. Otherwise one torpedo
    and one shot are consumed, the flight path is computed from index 0 up
    to and including the first intact ship cell (or the far edge), the
    animation state is reset and the game enters PHASE_TORPEDO.
    """
    shooter = g.p[g.current]
    board = g.p[opp(g.current)].grid

    if shooter.torpedos <= 0:
        g.message = "Kein Torpedo mehr!"
        return False

    shooter.torpedos -= 1
    shooter.shots += 1

    if horiz:
        lane = [(row, c) for c in range(GRID_SIZE)]
    else:
        lane = [(r, col) for r in range(GRID_SIZE)]

    route = []
    for r, c in lane:
        route.append((r, c))
        if board[r][c] == SHIP_OK:
            break

    g.torp_path = route
    g.torp_progress = 0.0
    g.torp_total_px = len(route) * CELL
    g.torp_was_hit = False
    g.torp_affected = []
    g.explode_timer = 0
    g.phase = PHASE_TORPEDO
    return True

def apply_torpedo_damage(g):
    """Apply the torpedo's effect along ``g.torp_path``.

    Empty cells on the path become MISS; the first intact ship cell
    becomes HIT and stops the torpedo. Stores the outcome in
    ``g.torp_was_hit`` / ``g.torp_affected`` and sets ``g.message``.
    """
    shooter = g.p[g.current]
    target = g.p[opp(g.current)]
    board = target.grid

    struck = False
    touched = []
    for r, c in g.torp_path:
        state = board[r][c]
        if state == SHIP_OK:
            board[r][c] = HIT
            shooter.hits += 1
            touched.append((r, c))
            struck = True
            break
        if state == EMPTY:
            board[r][c] = MISS
            touched.append((r, c))

    check_and_mark_sunk(board, target.ships)
    g.torp_was_hit = struck
    g.torp_affected = touched
    g.message = ("TORPEDO! Treffer! Nochmals schiessen!" if struck
                 else "Torpedo: Wasser...")

# ══════════════════════════════════════════════════════════
#  NACH RESULT
# ══════════════════════════════════════════════════════════

def go_after_result(g):
    """Leave the result screen: end the game, shoot again, or hand over.

    * Opponent has no intact cell left -> current player wins (PHASE_OVER).
    * ``g.stay_on_turn`` set -> same player continues in PHASE_BATTLE.
    * Otherwise the other player becomes current via the handoff screen.
    """
    enemy = opp(g.current)

    if all_sunk(g.p[enemy].grid):
        g.winner = g.current
        g.phase = PHASE_OVER
        pyxel.play(0, 4, loop=False)
        return

    if g.stay_on_turn:
        g.phase = PHASE_BATTLE
        return

    g.current = enemy
    g.phase = PHASE_HANDOFF
    g.next_phase = PHASE_BATTLE
    if g.current == 0:
        label = "Spieler 1"
    else:
        label = "Spieler 2"
    g.handoff_msg = label + " ist dran!\nDruecke SPACE wenn bereit."

# ══════════════════════════════════════════════════════════
#  UPDATE
# ══════════════════════════════════════════════════════════

def _confirm_pressed():
    """Return a truthy value on the frame SPACE or RETURN is newly pressed."""
    return pyxel.btnp(pyxel.KEY_SPACE) or pyxel.btnp(pyxel.KEY_RETURN)


def _moved_cursor(row, col):
    """Apply the arrow keys (with key repeat) to a cursor clamped to the board."""
    last = GRID_SIZE - 1
    # All four keys are polled on every call so their repeat counters in
    # key_active stay up to date.
    if key_active(pyxel.KEY_UP):
        row = max(0, row - 1)
    if key_active(pyxel.KEY_DOWN):
        row = min(last, row + 1)
    if key_active(pyxel.KEY_LEFT):
        col = max(0, col - 1)
    if key_active(pyxel.KEY_RIGHT):
        col = min(last, col + 1)
    return row, col


def _enter_handoff(player, next_phase, text):
    """Show the handoff screen for ``player`` before entering ``next_phase``."""
    g.current = player
    g.phase = PHASE_HANDOFF
    g.next_phase = next_phase
    g.handoff_msg = text


def _show_result(affected):
    """Record the outcome of a normal/bomb shot and enter the result phase."""
    board = g.p[opp(g.current)].grid
    scored = any(board[r][c] in (HIT, SUNK) for r, c in affected)
    g.last_was_hit = scored
    g.last_cells = affected
    g.stay_on_turn = scored
    g.anim_timer = 0
    g.phase = PHASE_RESULT


def _update_place():
    """Handle one frame of the ship placement phase."""
    length = SHIPS[g.ship_idx]
    g.cur_r, g.cur_c = _moved_cursor(g.cur_r, g.cur_c)
    if pyxel.btnp(pyxel.KEY_R):
        g.cur_horiz = not g.cur_horiz

    if not _confirm_pressed():
        return

    player = g.p[g.current]
    cells = ship_cells(g.cur_r, g.cur_c, length, g.cur_horiz)
    if not can_place(player.grid, cells):
        g.message = "Zu nah an einem anderen Schiff!"
        pyxel.play(0, 5, loop=False)
        return

    place_ship(player.grid, cells)
    player.ships.append(cells)
    pyxel.play(0, 0, loop=False)
    g.ship_idx += 1
    g.cur_r, g.cur_c, g.cur_horiz = 0, 0, True
    g.message = ""

    if g.ship_idx < len(SHIPS):
        return

    # The current player's fleet is complete.
    if g.current == 0:
        g.ship_idx = 0
        _enter_handoff(
            1, PHASE_PLACE,
            "Spieler 2 ist dran!\nPlatziere deine Schiffe.\nDruecke SPACE wenn bereit.")
    else:
        g.ship_idx = -1
        _enter_handoff(
            0, PHASE_BATTLE,
            "Alle Schiffe platziert!\nSpieler 1 beginnt.\nDruecke SPACE wenn bereit.")


def _update_battle():
    """Handle one frame of the battle phase: aim, pick a weapon, fire."""
    g.bc_r, g.bc_c = _moved_cursor(g.bc_r, g.bc_c)

    weapon_keys = (
        (pyxel.KEY_1, MODE_NORMAL),
        (pyxel.KEY_2, MODE_BOMB),
        (pyxel.KEY_3, MODE_TORPEDO),
    )
    for key, mode in weapon_keys:
        if pyxel.btnp(key):
            g.mode = mode
    if pyxel.btnp(pyxel.KEY_R) and g.mode == MODE_TORPEDO:
        g.torp_horiz = not g.torp_horiz

    if not _confirm_pressed():
        return

    if g.mode == MODE_TORPEDO:
        # The torpedo switches to its own animation phase when it fires.
        start_torpedo(g, g.bc_r, g.bc_c, g.torp_horiz)
        return

    if g.mode == MODE_NORMAL:
        shoot = fire_normal
    elif g.mode == MODE_BOMB:
        shoot = fire_bomb
    else:
        return

    fired, affected = shoot(g, g.bc_r, g.bc_c)
    if fired:
        _show_result(affected)


def update():
    """Advance the game by one frame according to the current phase."""
    global g

    phase = g.phase

    if phase == PHASE_TITLE:
        if _confirm_pressed():
            _enter_handoff(
                0, PHASE_PLACE,
                "Spieler 1 ist dran!\nPlatziere deine Schiffe.\nDruecke SPACE wenn bereit.")

    elif phase == PHASE_OVER:
        # R starts a completely fresh game.
        if pyxel.btnp(pyxel.KEY_R):
            g = Game()

    elif phase == PHASE_HANDOFF:
        if _confirm_pressed():
            g.phase = g.next_phase

    elif phase == PHASE_RESULT:
        g.anim_timer += 1
        if _confirm_pressed():
            go_after_result(g)

    elif phase == PHASE_TORPEDO:
        g.torp_progress += TORP_SPEED
        if g.torp_progress >= g.torp_total_px:
            g.torp_progress = float(g.torp_total_px)
            apply_torpedo_damage(g)
            g.explode_timer = 0
            g.phase = PHASE_EXPLODE

    elif phase == PHASE_EXPLODE:
        g.explode_timer += 1
        if g.explode_timer >= EXPLODE_DUR:
            g.last_was_hit = g.torp_was_hit
            g.last_cells = g.torp_affected
            g.stay_on_turn = g.torp_was_hit
            g.anim_timer = 0
            g.phase = PHASE_RESULT

    elif phase == PHASE_PLACE:
        _update_place()

    elif phase == PHASE_BATTLE:
        _update_battle()

# ══════════════════════════════════════════════════════════
#  DRAW – HILFSFUNKTIONEN
# ══════════════════════════════════════════════════════════

def shadow_text(x, y, s, col):
    """Draw ``s`` at ``(x, y)`` in ``col`` over a black one-pixel drop shadow."""
    # Shadow first (offset by 1), then the text itself on top.
    for offset, colour in ((1, BLACK), (0, col)):
        pyxel.text(x + offset, y + offset, s, colour)


# Ship length -> (u, v) source position of the horizontal ship sprite in
# image bank 0. draw_ship_sprite copies a (length * CELL) x CELL region
# starting there.
# NOTE(review): image bank 0 is not filled anywhere in the visible part of
# this file - presumably it is loaded elsewhere; confirm.
SHIP_SPRITES_H = {
    5: (  64,  128),
    4: (  64,  144),
    3: (  64,  192),
    2: (  64, 176),
}
# Ship length -> (u, v) source position of the vertical ship sprite in
# image bank 0. draw_ship_sprite copies a CELL x (length * CELL) region
# starting there.
SHIP_SPRITES_V = {
    5: ( 128,  48),
    4: (128,  144),
    3: (112,  192),
    2: (80,  216),
}

def draw_ship_sprite(ox, oy, cells, horiz):
    """Draw one ship on a board whose top-left pixel is ``(ox, oy)``.

    ``cells`` is the ship's list of ``(row, col)`` cells; the first cell is
    the sprite's anchor. Uses the sprite table matching the orientation
    and falls back to a grey rectangle when no sprite exists for the
    ship's length. Does nothing for an empty ``cells``.
    """
    if not cells:
        return

    size = len(cells)
    top, left = cells[0]
    px = ox + left * CELL
    py = oy + top * CELL
    long_side = size * CELL

    if horiz:
        table, w, h = SHIP_SPRITES_H, long_side, CELL
    else:
        table, w, h = SHIP_SPRITES_V, CELL, long_side

    if size in table:
        u, v = table[size]
        pyxel.blt(px, py, 0, u, v, w, h)
    else:
        # Fallback: grey rectangle
        pyxel.rect(px + 1, py + 1, w - 2, h - 2, DARK_GRAY)

def draw_grid(ox, oy, grid, show_ships, ships=None,
              cursor_r=-1, cursor_c=-1,
              highlight_cells=None, anim_timer=0):
    """Draw one board with its top-left pixel at ``(ox, oy)``.

    Args:
        ox, oy: pixel position of the board's top-left cell.
        grid: board to draw (rows of cell states).
        show_ships: when true, intact ship cells and the ship sprites in
            ``ships`` are drawn; otherwise intact cells look like water.
        ships: list of ships (each a list of ``(row, col)`` cells) used for
            the ship sprites.
        cursor_r, cursor_c: cell to outline with the cursor; nothing is
            drawn when the position is outside the board.
        highlight_cells: cells to overlay with the result animation.
        anim_timer: frame counter driving the highlight animation.

    Drawing order matters: cells, grid lines and hit overlays first, then
    ship sprites, then highlights, and the cursor on top.
    """

    # Border/background behind the whole board.
    pyxel.rect(ox-1, oy-1, GRID_SIZE*CELL+2, GRID_SIZE*CELL+2, INDIGO)

    for r in range(GRID_SIZE):
        for c in range(GRID_SIZE):
            x = ox + c * CELL
            y = oy + r * CELL
            state = grid[r][c]

            if state == MISS:
                # Water frame changes every 25 frames; r + c staggers
                # neighbouring cells.
                frame = (pyxel.frame_count // 25 + r + c) % 3

                # 1. Regular water
                pyxel.blt(x, y, 1, frame * 16, 0, 16, 16, 0)

                # 2. Slight darkening (diagonal stripes, every 4th pixel)
                for py_ in range(y, y + CELL):
                    for px_ in range(x, x + CELL):
                        if (px_ - py_) % 4 == 0:
                            pyxel.pset(px_, py_, DARK_BLUE)
                # 3. Small splash
                # NOTE(review): the board owner is inferred from show_ships;
                # a player's own board drawn with show_ships=False is looked
                # up under the opponent's index - confirm this is intended.
                player_id = opp(g.current) if not show_ships else g.current

                if (player_id, r, c) in g.miss_time:
                    t = pyxel.frame_count - g.miss_time[(player_id, r, c)]
                    # Splash is shown for the first 180 frames after the miss.
                    if t < 180:
                        pyxel.blt(x, y, 1, 64, 32, 16, 16, 0)
                # 4. Miss marker (five dots)
                mx, my = x + CELL // 2, y + CELL // 2
                pyxel.pset(mx-2, my-2, LIGHT_GRAY)
                pyxel.pset(mx+2, my-2, LIGHT_GRAY)
                pyxel.pset(mx,   my,   LIGHT_GRAY)
                pyxel.pset(mx-2, my+2, LIGHT_GRAY)
                pyxel.pset(mx+2, my+2, LIGHT_GRAY)
            elif state == HIT:
                pyxel.rect(x+1, y+1, CELL-2, CELL-2, BROWN)
            elif state == SUNK:
                pyxel.blt(x, y, 1, 80, 32, 16, 16, 0)
                # Small embers, blinking every 10 frames
                if (pyxel.frame_count // 10) % 2 == 0:
                    pyxel.pset(x+6, y+7, ORANGE)
                    pyxel.pset(x+9, y+8, ORANGE)
            elif state == SHIP_OK and show_ships:
                pyxel.rect(x+1, y+1, CELL-2, CELL-2, DARK_GRAY)
            else:
                # Empty water (also hidden ship cells): animated sprite
                pyxel.rect(x+1, y+1, CELL-2, CELL-2, DARK_BLUE)
                frame = (pyxel.frame_count // 25 + r + c) % 3
                pyxel.blt(x, y, 1, frame * 16, 0, 16, 16, 0)

            pyxel.rectb(x, y, CELL, CELL, COL_GRID)

            if state in (HIT, SUNK):
                # HIT cycles through the sprites at u = 16/32/48 every frame;
                # SUNK always uses the one at u = 48.
                ef = pyxel.frame_count % 3 if state == HIT else 2
                pyxel.blt(x, y, 1, ef * 16 + 16, 16, 16, 16, 0)

    # Draw the ship sprites
    if show_ships and ships:
        for ship in ships:
            if not ship:
                continue
            # A ship is horizontal when its first two cells share a row
            # (single-cell ships count as horizontal).
            sh = len(ship) < 2 or ship[0][0] == ship[1][0]
            draw_ship_sprite(ox, oy, ship, sh)

    # Highlight overlay
    if highlight_cells:
        blink_on = (anim_timer // 8) % 2 == 0
        for r, c in highlight_cells:
            x = ox + c * CELL
            y = oy + r * CELL
            state = grid[r][c]
            if state in (HIT, SUNK):
                flash_col = WHITE if blink_on else YELLOW
                pyxel.rectb(x+1, y+1, CELL-2, CELL-2, flash_col)
                # NOTE(review): exp_frame can reach 4, which samples u = 80;
                # load_sprites only fills explosion frames at u = 0..64 in
                # this row - confirm the empty final frame is intended.
                exp_frame = min(anim_timer // 8, 4)
                pyxel.blt(x, y, 1, exp_frame * 16 + 16, 16, 16, 16, 0)
            elif state == MISS:
                if blink_on:
                    pyxel.rect(x+1, y+1, CELL-2, CELL-2, CYAN)

    # Cursor: double outline, only when the position is on the board.
    if 0 <= cursor_r < GRID_SIZE and 0 <= cursor_c < GRID_SIZE:
        cx = ox + cursor_c * CELL
        cy = oy + cursor_r * CELL
        pyxel.rectb(cx,   cy,   CELL, CELL, YELLOW)
        pyxel.rectb(cx+1, cy+1, CELL-2, CELL-2, ORANGE)

def draw_ship_preview(ox, oy, grid, row, col, length, horiz):
    """Draw the placement preview of a ship on the board at ``(ox, oy)``.

    The footprint is filled green when the ship could be placed there and
    red otherwise; cells that fall outside the board are skipped.
    """
    footprint = ship_cells(row, col, length, horiz)
    fill = GREEN if can_place(grid, footprint) else RED

    for r, c in footprint:
        if not (0 <= r < GRID_SIZE and 0 <= c < GRID_SIZE):
            continue
        px = ox + c * CELL
        py = oy + r * CELL
        pyxel.rect(px + 1, py + 1, CELL - 2, CELL - 2, fill)
        pyxel.rectb(px, py, CELL, CELL, COL_GRID)

def draw_battle_ui(show_own, highlight_cells=None, anim_timer=0):
    """Draw the battle screen: headers, both boards and the weapon bar.

    ``show_own`` controls whether the current player's ships are visible
    on the left board. ``highlight_cells`` / ``anim_timer`` are forwarded
    to the opponent's board on the right.
    """
    me = g.p[g.current]
    enemy = g.p[opp(g.current)]
    if g.current == 0:
        pname = "Spieler 1"
    else:
        pname = "Spieler 2"

    # Header lines
    if show_own:
        own_label, own_col = "Dein Feld", CYAN
    else:
        own_label, own_col = "Dein Feld (A)", COL_DIM
    pyxel.text(LEFT_X, 4, own_label, own_col)
    pyxel.text(RIGHT_X, 4, "Gegner", COL_TEXT)
    status = f"{pname} B:{me.bombs} T:{me.torpedos} Schuss:{me.shots}"
    pyxel.text(LEFT_X, 14, status, COL_TEXT)

    # Boards: own on the left, opponent (ships hidden) on the right
    draw_grid(LEFT_X, GRID_Y, me.grid,
              show_ships=show_own, ships=me.ships)
    draw_grid(RIGHT_X, GRID_Y, enemy.grid, show_ships=False,
              highlight_cells=highlight_cells, anim_timer=anim_timer)

    # Weapon bar; the entry whose index equals g.mode is highlighted
    direction = "H" if g.torp_horiz else "V"
    entries = (
        "1:Normal",
        f"2:Bombe({me.bombs})",
        f"3:Torpedo({me.torpedos})[{direction}]",
    )
    bar_y = GRID_Y + GRID_SIZE * CELL + 4
    pen_x = LEFT_X
    for index, entry in enumerate(entries):
        colour = YELLOW if g.mode == index else COL_DIM
        pyxel.text(pen_x, bar_y, entry, colour)
        pen_x += len(entry) * 4 + 4

# ══════════════════════════════════════════════════════════
#  DRAW – TORPEDO ANIMATION
# ══════════════════════════════════════════════════════════

def draw_torpedo_phase():
    """Draw both boards plus the torpedo travelling along ``g.torp_path``.

    ``g.torp_progress`` is the distance travelled in pixels. Cells the
    torpedo has already left get a trail tile, an orange line connects the
    start of the path with the torpedo's current cell, and the torpedo
    sprite is drawn at its sub-cell position. Holding KEY_REVEAL shows the
    current player's own ships on the left board.
    """
    o        = opp(g.current)
    show_own = pyxel.btn(KEY_REVEAL)
    horiz    = g.torp_horiz
    path     = g.torp_path
    prog     = g.torp_progress

    pyxel.text(LEFT_X,  4,  "Dein Feld (A)", COL_DIM)
    pyxel.text(RIGHT_X, 4,  "Gegner",        COL_TEXT)
    draw_grid(LEFT_X,  GRID_Y, g.p[g.current].grid,
              show_ships=show_own, ships=g.p[g.current].ships)
    draw_grid(RIGHT_X, GRID_Y, g.p[o].grid, show_ships=False)

    if not path:
        return

    # Whole cells already passed, and the pixel offset inside the current one.
    cells_behind   = int(prog // CELL)
    offset_in_cell = prog - cells_behind * CELL

    # Trail tiles live in image bank 1 at (32, 32) for horizontal and
    # (48, 32) for vertical travel (see load_sprites).
    trail_u = 32 if horiz else 48
    for r, c in path[:cells_behind]:
        x = RIGHT_X + c * CELL
        y = GRID_Y  + r * CELL
        pyxel.blt(x, y, 1, trail_u, 32, 16, 16, 0)

    if cells_behind > 0:
        # Wake line from the centre of the first cell to the current cell.
        r0, c0   = path[0]
        ri       = min(cells_behind, len(path)-1)
        ri_r, ri_c = path[ri]
        pyxel.line(RIGHT_X + c0*CELL + CELL//2,   GRID_Y + r0*CELL + CELL//2,
             RIGHT_X + ri_c*CELL + CELL//2, GRID_Y + ri_r*CELL + CELL//2,
             ORANGE)

    # Torpedo head; clamped to the last path cell once the end is reached.
    head_cell = min(cells_behind, len(path)-1)
    hr, hc    = path[head_cell]
    if horiz:
        hx = RIGHT_X + hc * CELL + int(offset_in_cell)
        hy = GRID_Y  + hr * CELL
        pyxel.blt(hx, hy, 1, 0, 32, 16, 16, 0)
    else:
        hx = RIGHT_X + hc * CELL
        hy = GRID_Y  + hr * CELL + int(offset_in_cell)
        pyxel.blt(hx, hy, 1, 16, 32, 16, 16, 0)

    shadow_text(LEFT_X, GRID_Y + GRID_SIZE*CELL + 8,
                "Torpedo unterwegs...", ORANGE)

# ══════════════════════════════════════════════════════════
#  DRAW – EXPLOSIONS-PHASE
# ══════════════════════════════════════════════════════════

def draw_explode_phase():
    """Draw both boards plus the torpedo's impact animation.

    The effect is drawn on the last cell of ``g.torp_path`` and is driven
    by ``g.explode_timer``. On a hit the explosion runs through frames
    1..4 (small orange, medium, large yellow, fading red) with a blinking
    outline and a short burst of sparks; on a miss a splash is shown for
    the first 20 ticks. Holding KEY_REVEAL shows the current player's own
    ships on the left board.
    """
    o        = opp(g.current)
    show_own = pyxel.btn(KEY_REVEAL)
    t        = g.explode_timer
    hit      = g.torp_was_hit

    pyxel.text(LEFT_X,  4,  "Dein Feld (A)", COL_DIM)
    pyxel.text(RIGHT_X, 4,  "Gegner",        COL_TEXT)
    draw_grid(LEFT_X,  GRID_Y, g.p[g.current].grid,
              show_ships=show_own, ships=g.p[g.current].ships)
    draw_grid(RIGHT_X, GRID_Y, g.p[o].grid, show_ships=False)

    if not g.torp_path:
        return

    er, ec = g.torp_path[-1]
    ex = RIGHT_X + ec * CELL
    ey = GRID_Y  + er * CELL

    if hit:
        # ef is the explosion frame number as laid out by load_sprites:
        # frame n sits at u = n * 16 in the row at v = 16.
        if   t < 10: ef = 1
        elif t < 20: ef = 2
        elif t < 32: ef = 3
        else:        ef = 4
        frame_u = ef * 16
        # Five slightly offset copies make the blast look bigger.
        for dx, dy in [(0,0),(-2,-2),(2,-2),(-2,2),(2,2)]:
            pyxel.blt(ex+dx, ey+dy, 1, frame_u, 16, 16, 16, 0)
        pyxel.rectb(ex, ey, CELL, CELL, YELLOW if (t//4)%2==0 else RED)
        if t < 8:
            # Sparks scattered over the opponent's board, fewer each tick.
            for _ in range((8-t)*3):
                pyxel.pset(random.randint(RIGHT_X, RIGHT_X+GRID_SIZE*CELL),
                     random.randint(GRID_Y, GRID_Y+GRID_SIZE*CELL), YELLOW)
        shadow_text(LEFT_X, GRID_Y+GRID_SIZE*CELL+8, "TREFFER!", RED)
    else:
        if t < 20:
            pyxel.blt(ex, ey, 1, 64, 32, 16, 16, 0)
            if (t//6)%2==0:
                pyxel.rectb(ex, ey, CELL, CELL, CYAN)
        shadow_text(LEFT_X, GRID_Y+GRID_SIZE*CELL+8, "Wasser...", CYAN)

# ══════════════════════════════════════════════════════════
#  DRAW HAUPTFUNKTION
# ══════════════════════════════════════════════════════════

def draw():
    """Top-level draw callback: render the screen for the current phase.

    The title and hand-off screens paint their own background. All other
    phases share the common background (COL_BG plus a small animated dotted
    line near the bottom edge) before the phase-specific helper runs.
    """
    phase = g.phase

    if phase == PHASE_TITLE:
        _draw_title_screen()
        return

    if phase == PHASE_HANDOFF:
        _draw_handoff_screen()
        return

    # Common background for the remaining phases.
    pyxel.cls(COL_BG)
    for i in range(0, SCREEN_W, 8):
        # Alternating dots move up/down by one pixel every 30 frames.
        wave_y = (pyxel.frame_count // 30 + i // 8) % 2
        pyxel.pset(i, SCREEN_H - 4 + wave_y, INDIGO)

    if phase == PHASE_PLACE:
        _draw_place_phase()
        return

    if phase == PHASE_TORPEDO:
        draw_torpedo_phase()
        return

    if phase == PHASE_EXPLODE:
        draw_explode_phase()
        return

    # Holding KEY_REVEAL shows the current player's own ships.
    show_own = pyxel.btn(KEY_REVEAL)

    if phase == PHASE_BATTLE:
        _draw_battle_phase(show_own)
        return

    if phase == PHASE_RESULT:
        _draw_result_phase(show_own)
        return

    if phase == PHASE_OVER:
        _draw_game_over()


def _draw_title_screen():
    """Draw the title screen: tiled water, dimming dots, framed title box."""
    pyxel.cls(COL_BG)

    # Tile the whole screen with the three water frames from image bank 1.
    for row in range(SCREEN_H // CELL + 1):
        for col in range(SCREEN_W // CELL + 1):
            frame = (row + col) % 3
            pyxel.blt(col * CELL, row * CELL, 1, frame * 16, 0, 16, 16, 0)

    # Overlay a sparse dot raster (every third pixel) on top of the water.
    for y in range(0, SCREEN_H, 3):
        for x in range(0, SCREEN_W, 3):
            pyxel.pset(x, y, DARK_BLUE)

    title  = "SCHIFFE VERSENKEN"
    sub    = "2 Spieler"
    prompt = "SPACE druecken"
    cx = SCREEN_W // 2
    cy = SCREEN_H // 2

    # Framed box behind the text.
    bw, bh = 210, 70
    bx = cx - bw // 2
    by = cy - bh // 2 - 8
    pyxel.rect (bx - 2, by - 2, bw + 4, bh + 4, DARK_BLUE)
    pyxel.rectb(bx - 2, by - 2, bw + 4, bh + 4, INDIGO)
    pyxel.rectb(bx - 1, by - 1, bw + 2, bh + 2, DARK_GRAY)

    # Text is centred by assuming 4 pixels per character (len * 2 each side).
    ty = cy - 20

    # Title with a two-step drop shadow, drawn back to front.
    tx = cx - len(title) * 2
    for ox, oy, col in [(2, 2, BLACK), (1, 1, DARK_GRAY), (0, 0, WHITE)]:
        pyxel.text(tx + ox, ty + oy, title, col)

    # Subtitle
    sx = cx - len(sub) * 2
    for ox, oy, col in [(2, 2, BLACK), (1, 1, DARK_BLUE), (0, 0, CYAN)]:
        pyxel.text(sx + ox, ty + 10 + oy, sub, col)

    # Separator line
    pyxel.line(cx - 55, ty + 19, cx + 55, ty + 19, INDIGO)
    pyxel.line(cx - 55, ty + 20, cx + 55, ty + 20, DARK_GRAY)

    # Blinking prompt: visible for 25 frames, hidden for 25 frames.
    if (pyxel.frame_count // 25) % 2 == 0:
        px_ = cx - len(prompt) * 2
        pyxel.text(px_ + 1, ty + 28, prompt, DARK_GRAY)
        pyxel.text(px_,     ty + 27, prompt, YELLOW)


def _draw_handoff_screen():
    """Draw the hand-off screen: a centred box showing g.handoff_msg."""
    pyxel.cls(BLACK)
    for i in range(0, SCREEN_W, 16):
        for j in range(0, SCREEN_H, 16):
            pyxel.rectb(i, j, 16, 16, DARK_BLUE)

    # The box is sized from the longest line and the number of lines.
    lines = g.handoff_msg.split("\n")
    bw = max(len(ln) for ln in lines) * 4 + 20
    bh = len(lines) * 12 + 16
    bx = SCREEN_W // 2 - bw // 2
    by = SCREEN_H // 2 - bh // 2
    pyxel.rect(bx, by, bw, bh, DARK_BLUE)
    pyxel.rectb(bx, by, bw, bh, CYAN)
    pyxel.rectb(bx + 1, by + 1, bw - 2, bh - 2, INDIGO)

    start_y = by + 8
    for i, ln in enumerate(lines):
        shadow_text(SCREEN_W // 2 - len(ln) * 2, start_y + i * 12, ln, WHITE)


def _draw_place_phase():
    """Draw the ship-placement UI for the current player."""
    player = g.p[g.current]
    pname  = "Spieler 1" if g.current == 0 else "Spieler 2"
    length = SHIPS[g.ship_idx]
    orient = "Horizontal" if g.cur_horiz else "Vertikal"
    below_grid = GRID_Y + GRID_SIZE * CELL

    shadow_text(
        LEFT_X, 4,
        f"{pname}: Schiff {g.ship_idx + 1}/{len(SHIPS)}  Laenge: {length}",
        CYAN)
    pyxel.text(LEFT_X, 14, "Pfeile=Bew  R=Drehen  SPACE=Platz", COL_DIM)

    draw_grid(LEFT_X, GRID_Y, player.grid,
              show_ships=True, ships=player.ships,
              cursor_r=g.cur_r, cursor_c=g.cur_c)
    draw_ship_preview(LEFT_X, GRID_Y, player.grid,
                      g.cur_r, g.cur_c, length, g.cur_horiz)

    pyxel.text(LEFT_X, below_grid + 4, "Ausrichtung: " + orient, WHITE)
    if g.message:
        shadow_text(LEFT_X, below_grid + 14, g.message, RED)


def _draw_battle_phase(show_own):
    """Draw the aiming UI: boards, status line, key hints, target cursor.

    show_own -- forwarded to draw_battle_ui(show_own=...).
    """
    below_grid = GRID_Y + GRID_SIZE * CELL
    draw_battle_ui(show_own=show_own)

    # In torpedo mode the status line shows the firing direction instead of
    # the current message.
    if g.mode == MODE_TORPEDO:
        tdir = "H (R=wechseln)" if g.torp_horiz else "V (R=wechseln)"
        pyxel.text(LEFT_X, below_grid + 14, "Torpedo: " + tdir, CYAN)
    else:
        pyxel.text(LEFT_X, below_grid + 14, g.message, WHITE)
    pyxel.text(LEFT_X, below_grid + 24,
               "SPACE=Schiessen  A=Mein Feld", COL_DIM)

    # Two-colour frame around the targeted cell on the right-hand grid.
    cur_x = RIGHT_X + g.bc_c * CELL
    cur_y = GRID_Y  + g.bc_r * CELL
    pyxel.rectb(cur_x,     cur_y,     CELL,     CELL,     YELLOW)
    pyxel.rectb(cur_x + 1, cur_y + 1, CELL - 2, CELL - 2, ORANGE)


def _draw_result_phase(show_own):
    """Draw the shot result: highlighted cells, message, blinking prompt.

    show_own -- forwarded to draw_battle_ui(show_own=...).
    """
    below_grid = GRID_Y + GRID_SIZE * CELL
    draw_battle_ui(show_own=show_own,
                   highlight_cells=g.last_cells,
                   anim_timer=g.anim_timer)

    msg_col = RED if g.last_was_hit else COL_DIM
    shadow_text(LEFT_X, below_grid + 14, g.message, msg_col)

    # The prompt blinks with a 15-tick on / 15-tick off rhythm.
    if (g.anim_timer // 15) % 2 == 0:
        if g.stay_on_turn:
            shadow_text(LEFT_X, below_grid + 24,
                        "SPACE = nochmals schiessen!", YELLOW)
        else:
            pyxel.text(LEFT_X, below_grid + 24,
                       "SPACE = weiter  A=Mein Feld", WHITE)


def _draw_game_over():
    """Draw the game-over screen: winner headline, statistics, restart hint."""
    pyxel.cls(BLACK)
    for i in range(0, SCREEN_W, 16):
        for j in range(0, SCREEN_H, 16):
            pyxel.rectb(i, j, 16, 16, DARK_BLUE)

    wname = "Spieler 1" if g.winner == 0 else "Spieler 2"
    pw = g.p[g.winner]
    pl = g.p[opp(g.winner)]

    # The panel is centred horizontally on the actual screen width so that
    # it lines up with the headline, which is centred on SCREEN_W // 2.
    # All text positions are expressed relative to the panel's left edge.
    panel_w, panel_h = 196, 100
    panel_x = (SCREEN_W - panel_w) // 2
    panel_y = 50
    pyxel.rect(panel_x, panel_y, panel_w, panel_h, DARK_BLUE)
    pyxel.rectb(panel_x, panel_y, panel_w, panel_h, YELLOW)
    pyxel.rectb(panel_x + 1, panel_y + 1, panel_w - 2, panel_h - 2, ORANGE)

    headline = wname + "  gewinnt!"
    shadow_text(SCREEN_W // 2 - len(headline) * 2, 62, headline, YELLOW)

    pyxel.text(panel_x + 14, 82,
               f"Gewinner  - Schuss:{pw.shots}  Tref:{pw.hits}", WHITE)
    pyxel.text(panel_x + 14, 94,
               f"Verlierer - Schuss:{pl.shots}  Tref:{pl.hits}", COL_DIM)

    # Blinking restart hint: 20 frames on, 20 frames off.
    if (pyxel.frame_count // 20) % 2 == 0:
        shadow_text(panel_x + 30, 114, "R druecken = Neues Spiel", CYAN)

# ══════════════════════════════════════════════════════════
#  START
# ══════════════════════════════════════════════════════════
# Script entry point. The statements below are order-dependent; keep the
# sequence init -> load -> sprites -> sounds -> music -> run.

# Create the Pyxel window with the configured screen size.
pyxel.init(SCREEN_W, SCREEN_H, title="Schiffe Versenken 2 Spieler")
# NOTE(review): two resource files are loaded back to back. Presumably the
# second load replaces what the first one provided — confirm whether
# "res.pyxres" is still needed.
pyxel.load("res.pyxres")
pyxel.load("1res.pyxres")
# load_sprites() writes the sprite pixel data into pyxel.images[1]; it runs
# after the resource loads, presumably so that they cannot overwrite it.
load_sprites()
# setup_sounds() is defined elsewhere in this file; presumably it has to run
# after the resource loads for the same reason — TODO confirm.
setup_sounds()
# Start music track 0 in a loop.
pyxel.playm(0, loop=True)
# Hand control to Pyxel with update/draw as the per-frame callbacks.
pyxel.run(update, draw)