#!/usr/bin/env python3
"""Bab 12 Reinforcement Learning playground.

Small, offline, readable demos for:
- sales bandit (epsilon-greedy and UCB)
- gridworld SARSA and Q-learning
- mini Pong with tabular Q-learning
- policy-gradient bandit
- soft/SAC-style entropy pricing bandit
- toy RLHF preference reward model

The script uses only Python's standard library so it runs in terminal, VS Code,
Jupyter, Google Colab, and Kaggle without extra installation.
"""

from __future__ import annotations

import argparse
import json
import math
import os
import random
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, Iterable, List, Sequence, Tuple

# An action is identified by a plain integer code.
Action = int
# A state is a tuple of integers (e.g. Gridworld uses (x, y) coordinates).
State = Tuple[int, ...]
# Tabular action-value function keyed by (state, action) pairs.
QTable = Dict[Tuple[State, Action], float]


def softmax(values: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Return the softmax distribution of ``values`` in a numerically stable way.

    Args:
        values: Raw scores. May be empty.
        temperature: Positive divisor applied to every score before
            exponentiation.

    Returns:
        A list of probabilities with one entry per score, or an empty list
        when ``values`` is empty.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [v / temperature for v in values]
    if not scaled:
        # Nothing to normalise; avoid max() failing on an empty sequence.
        return []
    # Subtracting the maximum keeps every exponent <= 0, so exp() cannot overflow.
    m = max(scaled)
    exps = [math.exp(v - m) for v in scaled]
    total = sum(exps)
    return [v / total for v in exps]


def weighted_choice(rng: random.Random, probs: Sequence[float]) -> int:
    """Sample an index according to the probability vector ``probs``.

    One uniform draw is taken from ``rng`` and compared against the running
    cumulative sum of ``probs``. If the cumulative sum never reaches the draw
    (e.g. because of rounding), the last index is returned.
    """
    threshold = rng.random()
    fallback = len(probs) - 1
    cumulative = 0.0
    for index, weight in enumerate(probs):
        cumulative += weight
        if cumulative >= threshold:
            return index
    return fallback


def argmax_random_tie(rng: random.Random, values: Sequence[float]) -> int:
    """Return the index of the largest value, picking randomly among ties.

    Values within 1e-12 of the maximum count as tied. ``rng.choice`` is always
    called, even when only one candidate exists.
    """
    top = max(values)
    tied: List[int] = []
    for index, value in enumerate(values):
        if abs(value - top) < 1e-12:
            tied.append(index)
    return rng.choice(tied)


def epsilon_greedy_action(
    rng: random.Random,
    q: QTable,
    state: State,
    actions: Sequence[Action],
    epsilon: float,
) -> Action:
    """Choose an action for ``state`` with an epsilon-greedy rule.

    With probability ``epsilon`` a uniformly random action is returned;
    otherwise the action with the highest Q-value is returned, with ties
    broken at random.
    """
    exploring = rng.random() < epsilon
    if not exploring:
        scores = [q[(state, candidate)] for candidate in actions]
        return actions[argmax_random_tie(rng, scores)]
    return rng.choice(list(actions))


# ---------------------------------------------------------------------------
# Lab 1: Sales bandit
# ---------------------------------------------------------------------------


def sample_sales_reward(rng: random.Random, action: int) -> float:
    """Draw a synthetic profit for one of three promo choices.

    Actions: 0 = no discount, 1 = 10% discount, 2 = bundle. Each action has
    its own Gaussian (mean, standard deviation); the action with the highest
    mean is deliberately not the one with the highest variance.
    """
    # (mean, standard deviation) per action, indexed by action code.
    profiles = ((42.0, 6.0), (48.0, 11.0), (53.0, 8.0))
    mean, std = profiles[action]
    return rng.gauss(mean, std)


def run_epsilon_greedy_bandit(
    rng: random.Random,
    steps: int = 500,
    epsilon: float = 0.1,
) -> Dict[str, object]:
    """Play the three-arm sales bandit with an epsilon-greedy strategy.

    Each step explores a random arm with probability ``epsilon`` and otherwise
    exploits the arm with the best running-average estimate. Returns a summary
    dict with total reward, per-arm estimates and counts, and the name of the
    arm that looks best at the end.
    """
    arm_names = ["tanpa_diskon", "diskon_10", "bundling"]
    estimates = [0.0] * 3
    pulls = [0] * 3
    cumulative = 0.0
    for _ in range(steps):
        exploring = rng.random() < epsilon
        arm = rng.randrange(3) if exploring else argmax_random_tie(rng, estimates)
        payoff = sample_sales_reward(rng, arm)
        pulls[arm] += 1
        # Incremental sample mean: new = old + (x - old) / n.
        estimates[arm] += (payoff - estimates[arm]) / pulls[arm]
        cumulative += payoff
    winner = arm_names[argmax_random_tie(rng, estimates)]
    return {
        "strategy": "epsilon_greedy",
        "epsilon": epsilon,
        "total_reward": round(cumulative, 3),
        "estimated_values": [round(value, 3) for value in estimates],
        "action_counts": pulls,
        "best_action_name": winner,
    }


def run_ucb_bandit(rng: random.Random, steps: int = 500, c: float = 2.0) -> Dict[str, object]:
    """Play the three-arm sales bandit with an upper-confidence-bound rule.

    Every arm is tried once (lowest index first). After that, the arm with the
    highest ``estimate + c * sqrt(log(t) / pulls)`` score is chosen. Returns a
    summary dict with the same shape as the epsilon-greedy runner.
    """
    arm_names = ["tanpa_diskon", "diskon_10", "bundling"]
    estimates = [0.0] * 3
    pulls = [0] * 3
    cumulative = 0.0
    for t in range(1, steps + 1):
        first_untried = next((arm for arm, n in enumerate(pulls) if n == 0), None)
        if first_untried is not None:
            arm = first_untried
        else:
            log_t = math.log(t)
            bounds = [estimates[a] + c * math.sqrt(log_t / pulls[a]) for a in range(3)]
            arm = argmax_random_tie(rng, bounds)
        payoff = sample_sales_reward(rng, arm)
        pulls[arm] += 1
        # Incremental sample mean: new = old + (x - old) / n.
        estimates[arm] += (payoff - estimates[arm]) / pulls[arm]
        cumulative += payoff
    winner = arm_names[argmax_random_tie(rng, estimates)]
    return {
        "strategy": "ucb",
        "c": c,
        "total_reward": round(cumulative, 3),
        "estimated_values": [round(value, 3) for value in estimates],
        "action_counts": pulls,
        "best_action_name": winner,
    }


def sales_bandit_demo(seed: int = 7) -> Dict[str, object]:
    """Run both bandit strategies, each with its own RNG seeded from ``seed``."""
    greedy_rng = random.Random(seed)
    ucb_rng = random.Random(seed)
    report: Dict[str, object] = {}
    report["epsilon_greedy"] = run_epsilon_greedy_bandit(greedy_rng, epsilon=0.1)
    report["ucb"] = run_ucb_bandit(ucb_rng, c=2.0)
    return report


# ---------------------------------------------------------------------------
# Lab 2: Gridworld SARSA and Q-learning
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class Gridworld:
    """Small deterministic grid with a start, a goal, a trap and one wall cell.

    States are ``(x, y)`` cells. Moving "up" decreases ``y`` and moving "down"
    increases it. Moves off the grid are clamped to the border.
    """

    width: int = 5
    height: int = 5
    start: State = (0, 0)
    goal: State = (4, 4)
    trap: State = (3, 1)
    wall: State = (2, 2)

    @property
    def actions(self) -> Tuple[Action, ...]:
        """Available action codes: 0 up, 1 right, 2 down, 3 left."""
        return (0, 1, 2, 3)

    def step(self, state: State, action: Action) -> Tuple[State, float, bool]:
        """Apply ``action`` in ``state`` and return ``(next_state, reward, done)``.

        Reaching the goal yields +20 and ends the episode; reaching the trap
        yields -20 and ends the episode; any other move costs -1. Walking into
        the wall leaves the agent where it was.

        Raises:
            ValueError: If ``action`` is not one of the four known codes.
        """
        if action not in (0, 1, 2, 3):
            raise ValueError(f"unknown action {action}")
        # (dx, dy) offset for each action code.
        offsets = {0: (0, -1), 1: (1, 0), 2: (0, 1), 3: (-1, 0)}
        shift_x, shift_y = offsets[action]
        col, row = state
        col = max(0, min(self.width - 1, col + shift_x))
        row = max(0, min(self.height - 1, row + shift_y))
        candidate = (col, row)
        # The wall is impassable: bounce back to the original cell.
        landed = state if candidate == self.wall else candidate
        if landed == self.goal:
            return landed, 20.0, True
        if landed == self.trap:
            return landed, -20.0, True
        return landed, -1.0, False


def train_gridworld(
    algorithm: str,
    seed: int = 11,
    episodes: int = 700,
    alpha: float = 0.35,
    gamma: float = 0.92,
    epsilon: float = 0.14,
) -> Tuple[QTable, List[float]]:
    """Train a tabular agent on the default Gridworld.

    Args:
        algorithm: ``"sarsa"`` (on-policy) or ``"q_learning"`` (off-policy).
        seed: Seed for the private random generator used during training.
        episodes: Number of training episodes.
        alpha: Learning rate of the temporal-difference update.
        gamma: Discount factor.
        epsilon: Exploration probability of the epsilon-greedy behaviour policy.

    Returns:
        ``(q, returns)`` where ``q`` is the learned Q-table and ``returns``
        holds the undiscounted return of every episode.

    Raises:
        ValueError: If ``algorithm`` is not a supported name.
    """
    # Fail fast: reject a bad algorithm name before any work is done, so the
    # error also surfaces when ``episodes`` is 0.
    if algorithm not in ("sarsa", "q_learning"):
        raise ValueError("algorithm must be 'sarsa' or 'q_learning'")
    rng = random.Random(seed)
    env = Gridworld()
    q: QTable = defaultdict(float)
    returns: List[float] = []
    for _ in range(episodes):
        state = env.start
        action = epsilon_greedy_action(rng, q, state, env.actions, epsilon)
        total = 0.0
        # Cap episode length so a wandering policy cannot loop forever.
        for _step in range(100):
            next_state, reward, done = env.step(state, action)
            total += reward
            if algorithm == "sarsa":
                # SARSA bootstraps from the action the behaviour policy will take.
                next_action = epsilon_greedy_action(rng, q, next_state, env.actions, epsilon)
                target = reward if done else reward + gamma * q[(next_state, next_action)]
                q[(state, action)] += alpha * (target - q[(state, action)])
                state, action = next_state, next_action
            else:
                # Q-learning bootstraps from the greedy value of the next state.
                best_next = max(q[(next_state, a)] for a in env.actions)
                target = reward if done else reward + gamma * best_next
                q[(state, action)] += alpha * (target - q[(state, action)])
                state = next_state
                action = epsilon_greedy_action(rng, q, state, env.actions, epsilon)
            if done:
                break
        returns.append(total)
    return q, returns


def greedy_grid_policy(q: QTable) -> List[str]:
    """Render the greedy policy of ``q`` as one text row per grid row.

    Special cells are shown as S (start), G (goal), X (trap) and # (wall).
    Every other cell shows an arrow for its greedy action, with ties broken by
    a fixed-seed RNG.
    """
    env = Gridworld()
    arrows = {0: "↑", 1: "→", 2: "↓", 3: "←"}
    # Inserted in reverse priority so that start > goal > trap > wall on overlap.
    markers = {env.wall: "#", env.trap: "X", env.goal: "G", env.start: "S"}
    tie_rng = random.Random(0)
    lines: List[str] = []
    for y in range(env.height):
        row: List[str] = []
        for x in range(env.width):
            cell = (x, y)
            glyph = markers.get(cell)
            if glyph is None:
                action_values = [q[(cell, a)] for a in env.actions]
                glyph = arrows[argmax_random_tie(tie_rng, action_values)]
            row.append(glyph)
        lines.append(" ".join(row))
    return lines


def evaluate_grid_policy(q: QTable, episodes: int = 50) -> Dict[str, float]:
    """Roll out the greedy policy of ``q`` and report its average performance.

    Each episode runs for at most 50 steps. Returns the mean undiscounted
    return and the fraction of episodes that ended on the goal cell.
    """
    env = Gridworld()
    rng = random.Random(123)
    goal_hits = 0
    cumulative = 0.0
    for _ in range(episodes):
        position = env.start
        episode_return = 0.0
        for _step in range(50):
            scores = [q[(position, a)] for a in env.actions]
            chosen = env.actions[argmax_random_tie(rng, scores)]
            position, reward, finished = env.step(position, chosen)
            episode_return += reward
            if not finished:
                continue
            if position == env.goal:
                goal_hits += 1
            break
        cumulative += episode_return
    return {
        "average_return": round(cumulative / episodes, 3),
        "success_rate": round(goal_hits / episodes, 3),
    }


def gridworld_demo(seed: int = 11) -> Dict[str, object]:
    """Train SARSA and Q-learning on Gridworld with the same seed and compare them."""

    def summarize(table: QTable, history: List[float]) -> Dict[str, object]:
        # Report recent training return, greedy evaluation and the drawn policy.
        return {
            "last_50_average_return": round(sum(history[-50:]) / 50, 3),
            "evaluation": evaluate_grid_policy(table),
            "policy": greedy_grid_policy(table),
        }

    sarsa_run = train_gridworld("sarsa", seed=seed)
    q_learning_run = train_gridworld("q_learning", seed=seed)
    return {
        "sarsa": summarize(*sarsa_run),
        "q_learning": summarize(*q_learning_run),
    }


# ---------------------------------------------------------------------------
# Lab 3: Mini Pong with tabular Q-learning
# ---------------------------------------------------------------------------


@dataclass
class MiniPong:
    """Tiny one-paddle Pong on a ``width`` x ``height`` grid.

    A state is ``(ball_x, ball_y, ball_dx, ball_dy, paddle_y)``. The paddle
    guards the right edge; the episode ends when the ball reaches it.
    """

    height: int = 5
    width: int = 5

    @property
    def actions(self) -> Tuple[Action, ...]:
        """Paddle moves: -1 up, 0 stay, +1 down."""
        return (-1, 0, 1)

    def reset(self, rng: random.Random) -> State:
        """Start an episode: ball on the left at a random row, paddle centred."""
        # Draw order matters for reproducibility: row first, then direction.
        start_row = rng.randrange(self.height)
        vertical_dir = rng.choice([-1, 1])
        return (0, start_row, 1, vertical_dir, self.height // 2)

    def step(self, state: State, action: Action) -> Tuple[State, float, bool]:
        """Move the paddle, advance the ball, return ``(state, reward, done)``.

        Reward is +1 when the paddle is on the ball's row as the ball reaches
        the right edge, -1 when it is not, and 0 on every other step.
        """
        ball_x, ball_y, vel_x, vel_y, paddle = state
        paddle = max(0, min(self.height - 1, paddle + action))
        ball_x, ball_y = ball_x + vel_x, ball_y + vel_y
        # Bounce off the top and bottom walls.
        if ball_y < 0:
            ball_y, vel_y = 1, 1
        elif ball_y >= self.height:
            ball_y, vel_y = self.height - 2, -1
        right_edge = self.width - 1
        if ball_x >= right_edge:
            if ball_y == paddle:
                # Hit: the ball is sent back one column to the left.
                return (self.width - 2, ball_y, -1, vel_y, paddle), 1.0, True
            return (right_edge, ball_y, vel_x, vel_y, paddle), -1.0, True
        # Bounce off the left wall.
        if ball_x <= 0:
            ball_x, vel_x = 0, 1
        return (ball_x, ball_y, vel_x, vel_y, paddle), 0.0, False


def train_mini_pong(seed: int = 21, episodes: int = 2500) -> Tuple[QTable, List[float]]:
    """Train a tabular Q-learning agent on MiniPong.

    Exploration decays linearly from 0.35 towards a floor of 0.03 over the
    run. Returns the learned Q-table and the per-episode returns.
    """
    rng = random.Random(seed)
    env = MiniPong()
    q: QTable = defaultdict(float)
    step_size = 0.25
    discount = 0.95
    history: List[float] = []
    for episode in range(episodes):
        remaining = 1 - episode / episodes
        epsilon = max(0.03, 0.35 * remaining)
        state = env.reset(rng)
        episode_return = 0.0
        for _tick in range(20):
            action = epsilon_greedy_action(rng, q, state, env.actions, epsilon)
            next_state, reward, done = env.step(state, action)
            # Looked up unconditionally, as the Q-table is a defaultdict.
            lookahead = max(q[(next_state, a)] for a in env.actions)
            if done:
                target = reward
            else:
                target = reward + discount * lookahead
            key = (state, action)
            q[key] += step_size * (target - q[key])
            state = next_state
            episode_return += reward
            if done:
                break
        history.append(episode_return)
    return q, history


def evaluate_mini_pong(q: QTable, episodes: int = 300) -> Dict[str, float]:
    """Play MiniPong greedily with ``q`` and report win rate and mean return.

    Episodes are capped at 20 steps; only the terminal reward contributes to
    the return.
    """
    rng = random.Random(222)
    env = MiniPong()
    victories = 0
    reward_sum = 0.0
    for _episode in range(episodes):
        state = env.reset(rng)
        for _tick in range(20):
            scores = [q[(state, a)] for a in env.actions]
            move = env.actions[argmax_random_tie(rng, scores)]
            state, reward, finished = env.step(state, move)
            if not finished:
                continue
            if reward > 0:
                victories += 1
            reward_sum += reward
            break
    return {
        "win_rate": round(victories / episodes, 3),
        "average_return": round(reward_sum / episodes, 3),
    }


def mini_pong_demo(seed: int = 21) -> Dict[str, object]:
    """Train and evaluate the MiniPong agent, returning a JSON-friendly summary."""
    table, history = train_mini_pong(seed=seed)
    recent_average = round(sum(history[-200:]) / 200, 3)
    summary: Dict[str, object] = {"last_200_average_return": recent_average}
    summary["evaluation"] = evaluate_mini_pong(table)
    summary["state_example"] = "(ball_x, ball_y, ball_dx, ball_dy, paddle_y)"
    summary["actions"] = {"-1": "up", "0": "stay", "1": "down"}
    return summary


# ---------------------------------------------------------------------------
# Lab 4: Policy gradient bandit
# ---------------------------------------------------------------------------


def policy_gradient_bandit_demo(seed: int = 31, steps: int = 700) -> Dict[str, object]:
    """Learn softmax action preferences on a three-arm Gaussian bandit.

    Preferences are updated with a gradient-bandit rule that uses the running
    mean reward as a baseline. Returns the final action probabilities, the raw
    preferences and the index of the truly best arm.
    """
    rng = random.Random(seed)
    true_means = [0.1, 0.35, 0.8]
    preferences = [0.0] * 3
    step_size = 0.08
    baseline = 0.0
    for t in range(1, steps + 1):
        policy = softmax(preferences)
        chosen = weighted_choice(rng, policy)
        reward = rng.gauss(true_means[chosen], 0.25)
        # Running mean of rewards; it already includes the current reward.
        baseline += (reward - baseline) / t
        scaled_advantage = step_size * (reward - baseline)
        for arm, prob in enumerate(policy):
            indicator = 1.0 if arm == chosen else 0.0
            preferences[arm] += scaled_advantage * (indicator - prob)
    final_policy = softmax(preferences)
    return {
        "final_action_probabilities": [round(p, 3) for p in final_policy],
        "preferences": [round(v, 3) for v in preferences],
        "best_true_action": 2,
    }


# ---------------------------------------------------------------------------
# Lab 5: Soft/SAC-style entropy pricing bandit
# ---------------------------------------------------------------------------


def sample_price_profit(rng: random.Random, price: int) -> float:
    """Draw a noisy profit for ``price`` from a synthetic demand curve.

    Expected profit is a downward parabola peaking at price 22; Gaussian noise
    with standard deviation 3 is added.
    """
    distance_from_peak = price - 22
    mean_profit = 90.0 - 0.9 * distance_from_peak ** 2
    return rng.gauss(mean_profit, 3.0)


def soft_pricing_bandit_demo(seed: int = 41, steps: int = 600, temperature: float = 0.7) -> Dict[str, object]:
    """Run a softmax (entropy-regularised style) pricing bandit.

    Prices are sampled from a softmax over the running profit estimates, so
    the policy stays stochastic instead of collapsing onto one price.

    Args:
        seed: Seed for the private random generator.
        steps: Number of pricing decisions to simulate.
        temperature: Softmax temperature; values below 0.05 are raised to 0.05.

    Returns:
        A summary dict with the price list, profit estimates, selection
        counts, final policy probabilities and the mean policy entropy over
        the last (up to) 100 steps.
    """
    rng = random.Random(seed)
    prices = [18, 20, 22, 24]
    q = [0.0 for _ in prices]
    counts = [0 for _ in prices]
    entropy_trace: List[float] = []
    # Clamp once: the effective temperature never changes during the run.
    effective_temperature = max(temperature, 0.05)
    for _ in range(steps):
        probs = softmax(q, temperature=effective_temperature)
        action_idx = weighted_choice(rng, probs)
        reward = sample_price_profit(rng, prices[action_idx])
        counts[action_idx] += 1
        # Incremental sample mean of the observed profit for this price.
        q[action_idx] += (reward - q[action_idx]) / counts[action_idx]
        # Shannon entropy of the sampling policy; the floor guards log(0).
        entropy = -sum(p * math.log(max(p, 1e-12)) for p in probs)
        entropy_trace.append(entropy)
    final_probs = softmax(q, temperature=effective_temperature)
    # Average over the entries actually recorded: with fewer than 100 steps a
    # fixed divisor of 100 would under-report the mean entropy.
    recent_entropy = entropy_trace[-100:]
    average_entropy = sum(recent_entropy) / len(recent_entropy) if recent_entropy else 0.0
    return {
        "prices_thousand_rupiah": prices,
        "estimated_profit": [round(v, 3) for v in q],
        "selection_counts": counts,
        "final_policy_probabilities": [round(p, 3) for p in final_probs],
        "average_entropy_last_100": round(average_entropy, 3),
        "note": "SAC-style intuition only: reward plus entropy, not full SAC.",
    }


# ---------------------------------------------------------------------------
# Lab 6: Toy RLHF preference model
# ---------------------------------------------------------------------------


def answer_features(answer: str) -> List[float]:
    """Turn an answer string into ``[bias, length, caveat, steps, overclaim]``.

    ``length`` is the character count scaled by 160 and capped at 1. The other
    three features are 0/1 flags for keyword matches in the lower-cased text.
    """
    text = answer.lower()

    def flag(keywords: Sequence[str]) -> float:
        # 1.0 when any keyword occurs as a substring of the lower-cased answer.
        return 1.0 if any(word in text for word in keywords) else 0.0

    return [
        1.0,
        min(len(answer) / 160.0, 1.0),
        flag(["tidak tahu", "perlu cek", "berdasarkan konteks"]),
        flag(["langkah", "pertama", "kedua", "contoh"]),
        flag(["pasti", "selalu", "dijamin"]),
    ]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Return the dot product of ``a`` and ``b`` (extra trailing items are ignored)."""
    products = [left * right for left, right in zip(a, b)]
    return sum(products)


def toy_rlhf_preference_demo(seed: int = 51, epochs: int = 200) -> Dict[str, object]:
    """Fit a tiny linear reward model on four hand-written preference pairs.

    Each pair is ``(answer_a, answer_b, preferred_index)``. The model is
    trained with a logistic pairwise loss on the feature difference, using
    plain stochastic gradient steps. Returns the learned weights and the three
    highest-scoring answers.
    """
    rng = random.Random(seed)
    pairs = [
        (
            "Diskon 20% pasti selalu terbaik untuk menaikkan omzet.",
            "Diskon perlu diuji. Pertama cek margin, lalu bandingkan promo kecil dan bundling.",
            1,
        ),
        (
            "Q-learning dijamin cocok untuk semua masalah bisnis.",
            "Q-learning cocok untuk sebagian MDP kecil; untuk bisnis nyata perlu guardrail dan evaluasi.",
            1,
        ),
        (
            "Saya tidak tahu data pastinya; berdasarkan konteks, mulai dari bandit lebih aman.",
            "Pakai deep RL saja agar modern.",
            0,
        ),
        (
            "Contoh langkah: definisikan state, action, reward, lalu uji baseline sederhana.",
            "RL itu membuat AI pintar sendiri.",
            0,
        ),
    ]
    weights = [0.0] * 5
    lr = 0.08
    for _ in range(epochs):
        rng.shuffle(pairs)
        for first, second, preferred_index in pairs:
            first_features = answer_features(first)
            second_features = answer_features(second)
            margin = dot(weights, second_features) - dot(weights, first_features)
            # Probability that the first answer is preferred (logistic model).
            prob_first = 1.0 / (1.0 + math.exp(margin))
            label = 1.0 if preferred_index == 0 else 0.0
            scaled_error = lr * (label - prob_first)
            for i, (f_first, f_second) in enumerate(zip(first_features, second_features)):
                weights[i] += scaled_error * (f_first - f_second)
    scored = [
        (round(dot(weights, answer_features(candidate)), 3), candidate)
        for first, second, _ in pairs
        for candidate in (first, second)
    ]
    scored.sort(reverse=True)
    return {
        "feature_names": ["bias", "length", "caveat", "steps", "overclaim"],
        "learned_weights": [round(w, 3) for w in weights],
        "top_scored_answers": scored[:3],
        "warning": "Toy reward model: useful for intuition, not a factuality or safety guarantee.",
    }


# ---------------------------------------------------------------------------
# Orchestration, tests, and CLI
# ---------------------------------------------------------------------------


def run_demo(seed: int = 123) -> Dict[str, object]:
    """Run every lab once and collect the results in one JSON-friendly dict.

    Each lab receives its own seed, derived from ``seed`` by a fixed offset.
    """
    report: Dict[str, object] = {
        "metadata": {
            "chapter": "Bab 12 — Reinforcement Learning",
            "seed": seed,
            "note": "All demos are educational, small, and offline.",
        },
    }
    report["sales_bandit"] = sales_bandit_demo(seed + 1)
    report["gridworld_sarsa_vs_q_learning"] = gridworld_demo(seed + 2)
    report["mini_pong"] = mini_pong_demo(seed + 3)
    report["policy_gradient_bandit"] = policy_gradient_bandit_demo(seed + 4)
    report["soft_sac_style_pricing"] = soft_pricing_bandit_demo(seed + 5)
    report["toy_rlhf_preference_model"] = toy_rlhf_preference_demo(seed + 6)
    return report


def save_results(results: Dict[str, object], output_path: str) -> None:
    """Write ``results`` to ``output_path`` as pretty-printed UTF-8 JSON.

    Missing parent directories are created. A bare filename (no directory
    part) is written to the current working directory.

    Args:
        results: JSON-serialisable mapping to store.
        output_path: Destination file path.
    """
    parent = os.path.dirname(output_path)
    # dirname() is "" for a bare filename, and os.makedirs("") raises
    # FileNotFoundError, so only create directories when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
        f.write("\n")


def self_test() -> None:
    """Run quick deterministic sanity checks over every lab.

    Raises:
        AssertionError: If any check fails.
    """
    distribution = softmax([1.0, 2.0, 3.0])
    assert abs(sum(distribution) - 1.0) < 1e-9
    assert distribution[0] < distribution[1] < distribution[2]

    bandit_report = sales_bandit_demo(7)
    assert all(name in bandit_report for name in ("epsilon_greedy", "ucb"))
    assert sum(bandit_report["ucb"]["action_counts"]) == 500

    q_learning_report = gridworld_demo(11)["q_learning"]
    assert q_learning_report["evaluation"]["success_rate"] >= 0.8
    assert len(q_learning_report["policy"]) == 5

    pong_report = mini_pong_demo(21)
    assert pong_report["evaluation"]["win_rate"] >= 0.65

    gradient_report = policy_gradient_bandit_demo(31)
    assert gradient_report["final_action_probabilities"][2] > 0.5

    pricing_report = soft_pricing_bandit_demo(41)
    assert len(pricing_report["final_policy_probabilities"]) == 4

    preference_report = toy_rlhf_preference_demo(51)
    # The overclaim feature (index 4) should end up with a negative weight.
    assert preference_report["learned_weights"][4] < 0.0


def main(argv: Sequence[str] | None = None) -> None:
    """Command-line entry point.

    With ``--self-test`` the deterministic checks are run and nothing is
    written. Otherwise all demos run, the results are saved as JSON to
    ``--output`` and also printed to stdout.
    """
    default_output = os.path.join("outputs", "bab12_rl_results.json")
    parser = argparse.ArgumentParser(description="Bab 12 RL playground")
    parser.add_argument("--self-test", action="store_true", help="run deterministic checks")
    parser.add_argument("--seed", type=int, default=123, help="random seed")
    parser.add_argument("--output", default=default_output, help="output JSON path")
    options = parser.parse_args(argv)

    if not options.self_test:
        results = run_demo(options.seed)
        save_results(results, options.output)
        print(json.dumps(results, ensure_ascii=False, indent=2))
        print(f"\nSaved results to {options.output}")
        return

    self_test()
    print("self-test passed")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
