#!/usr/bin/env python3
"""Bab 10 Deep Learning Playground.

Ringan dan sengaja ditulis dengan NumPy agar pembaca bisa mengetik manual
sebagian fungsi dari ebook. Bisa dijalankan di terminal, VS Code, Jupyter,
Google Colab, dan Kaggle Notebook.

Jalankan:
  python3 deep_learning_playground.py
  python3 deep_learning_playground.py --self-test
"""
from __future__ import annotations

import argparse
import contextlib
import io
import json
import math
from pathlib import Path

import numpy as np


# Generated artefacts (loss plot PNG, demo results JSON) are written to an
# "outputs" folder beside this script, resolved to an absolute path so the
# location does not depend on the current working directory.
OUTPUT_DIR = Path(__file__).resolve().parent.joinpath("outputs")


def conv2d_valid(x: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """2D "valid" convolution (cross-correlation) for a single channel.

    Deep learning frameworks usually implement cross-correlation rather than
    the classic convolution with a flipped kernel; here the kernel is applied
    as-is, which is enough for learning CNNs. "Valid" means no padding, so the
    output has shape ``(H - kh + 1, W - kw + 1)``. The explicit double loop is
    kept on purpose so readers can type it by hand and follow each window.

    Args:
        x: 2D input of shape ``(H, W)``; converted to ``float``.
        kernel: non-empty 2D kernel of shape ``(kh, kw)`` with ``kh <= H``
            and ``kw <= W``; converted to ``float``.

    Returns:
        Float array of shape ``(H - kh + 1, W - kw + 1)``.

    Raises:
        ValueError: if either input is not 2D, the kernel is empty, or the
            kernel is larger than ``x`` in either dimension.
    """
    x = np.asarray(x, dtype=float)
    kernel = np.asarray(kernel, dtype=float)
    if x.ndim != 2 or kernel.ndim != 2:
        raise ValueError(
            f"conv2d_valid expects 2D x and kernel, got shapes {x.shape} and {kernel.shape}"
        )
    if kernel.size == 0:
        raise ValueError("conv2d_valid kernel must not be empty")
    kh, kw = kernel.shape
    # Without this check a kernel one larger than x silently yields an empty
    # array and a bigger one fails deep inside np.zeros with negative sizes.
    if kh > x.shape[0] or kw > x.shape[1]:
        raise ValueError(f"kernel shape {kernel.shape} is larger than input shape {x.shape}")
    oh, ow = x.shape[0] - kh + 1, x.shape[1] - kw + 1
    out = np.zeros((oh, ow), dtype=float)
    for i in range(oh):
        for j in range(ow):
            patch = x[i : i + kh, j : j + kw]
            out[i, j] = np.sum(patch * kernel)
    return out


def max_pool2d(x: np.ndarray, size: int = 2, stride: int = 2) -> np.ndarray:
    """2D max pooling for a single channel.

    Slides a ``size x size`` window over ``x`` in steps of ``stride`` and keeps
    the largest value of each window. Trailing rows/columns that cannot fill a
    whole window are dropped (the output size is floored).

    Args:
        x: 2D input of shape ``(H, W)``; converted to ``float``.
        size: window side length, ``>= 1`` and not larger than ``H`` or ``W``.
        stride: step between windows, ``>= 1``.

    Returns:
        Float array of shape
        ``((H - size) // stride + 1, (W - size) // stride + 1)``.

    Raises:
        ValueError: if ``x`` is not 2D, ``size`` or ``stride`` is below 1, or
            the window is larger than the input.
    """
    x = np.asarray(x, dtype=float)
    if x.ndim != 2:
        raise ValueError(f"max_pool2d expects a 2D array, got shape {x.shape}")
    # stride=0 would otherwise surface as ZeroDivisionError, and an oversized
    # window as an empty result or a negative-size error from np.zeros.
    if size < 1 or stride < 1:
        raise ValueError(f"size and stride must be >= 1, got size={size}, stride={stride}")
    if size > x.shape[0] or size > x.shape[1]:
        raise ValueError(f"pool size {size} is larger than input shape {x.shape}")
    oh = (x.shape[0] - size) // stride + 1
    ow = (x.shape[1] - size) // stride + 1
    out = np.zeros((oh, ow), dtype=float)
    for i in range(oh):
        for j in range(ow):
            patch = x[i * stride : i * stride + size, j * stride : j * stride + size]
            out[i, j] = np.max(patch)
    return out


def conv_output_size(input_size: int, kernel: int, padding: int = 0, stride: int = 1) -> int:
    """Return the output length along one spatial axis of a conv/pool layer.

    Implements ``floor((input_size - kernel + 2 * padding) / stride) + 1``.
    The result is not clamped: it is below 1 when the kernel is larger than
    the padded input.

    Args:
        input_size: length of the input along the axis.
        kernel: kernel (window) length along the axis.
        padding: zero padding added on each side.
        stride: step between windows; must be ``>= 1``.

    Returns:
        The output length as an ``int``.

    Raises:
        ValueError: if ``stride`` is below 1.
    """
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")
    # Floor division keeps integer arithmetic exact; true division goes
    # through a float and is off by one for very large sizes (e.g. 10**20).
    return int((input_size - kernel + 2 * padding) // stride) + 1


def conv_params(kernel: int, in_channels: int, out_channels: int, bias: bool = True) -> int:
    """Count the learnable parameters of a square 2D convolution layer.

    Each of the ``out_channels`` filters has ``kernel * kernel * in_channels``
    weights, plus one bias per output channel when ``bias`` is true.
    """
    weight_count = kernel * kernel * in_channels * out_channels
    bias_count = out_channels if bias else 0
    return weight_count + bias_count


def conv_multadds(kernel: int, in_channels: int, out_channels: int, feature_size: int) -> int:
    """Multiply-add count of a standard square convolution.

    Every output position of a ``feature_size x feature_size`` map costs
    ``kernel * kernel * in_channels * out_channels`` multiply-adds.
    """
    per_position = kernel * kernel * in_channels * out_channels
    return per_position * feature_size * feature_size


def depthwise_separable_multadds(kernel: int, in_channels: int, out_channels: int, feature_size: int) -> int:
    """Multiply-add count of a depthwise separable convolution.

    Sum of a depthwise step (one ``kernel x kernel`` filter per input
    channel) and a pointwise 1x1 step mapping ``in_channels`` to
    ``out_channels``, both over a ``feature_size x feature_size`` map.
    """
    depthwise_per_pixel = kernel * kernel * in_channels
    pointwise_per_pixel = in_channels * out_channels
    return (
        depthwise_per_pixel * feature_size * feature_size
        + pointwise_per_pixel * feature_size * feature_size
    )


def residual_block(x: np.ndarray, weight: float = 0.25) -> np.ndarray:
    """Toy residual block: H(x) = F(x) + x, with F(x) = ReLU(weight * x).

    The skip connection adds the float-converted input back onto the
    branch output.
    """
    inputs = np.asarray(x, dtype=float)
    branch = np.maximum(0, weight * inputs)
    return branch + inputs


def rnn_step(x_t: np.ndarray, h_prev: np.ndarray, wx: np.ndarray, wh: np.ndarray, b: np.ndarray) -> np.ndarray:
    """One vanilla RNN step: h_t = tanh(wx @ x_t + wh @ h_prev + b).

    Operand shapes must be compatible with the matrix products, e.g.
    ``wx``: (hidden, input), ``x_t``: (input,), ``wh``: (hidden, hidden).
    """
    pre_activation = wx @ x_t + wh @ h_prev + b
    return np.tanh(pre_activation)


def sigmoid(x: np.ndarray) -> np.ndarray:
    """Numerically stable logistic sigmoid, 1 / (1 + exp(-x)).

    Accepts scalars and array-likes (converted to ``float`` like ``softmax``).
    Only ``exp(-|x|)`` is evaluated, which lies in (0, 1], so large-magnitude
    inputs never overflow: for ``x >= 0`` it uses ``1 / (1 + exp(-x))`` and for
    ``x < 0`` the algebraically equal ``exp(x) / (1 + exp(x))``.

    Returns:
        A NumPy float scalar for scalar input, otherwise an array of the same
        shape as ``x``.
    """
    x = np.asarray(x, dtype=float)
    z = np.exp(-np.abs(x))
    out = np.where(x >= 0, 1.0 / (1.0 + z), z / (1.0 + z))
    # out[()] unwraps a 0-d result to a scalar and is a no-op view otherwise.
    return out[()]


def lstm_step_scalar(c_prev: float, forget: float, input_gate: float, candidate: float, output_gate: float) -> tuple[float, float]:
    """Scalar LSTM cell update from already-computed gate activations.

    Returns ``(c_t, h_t)`` where
    ``c_t = forget * c_prev + input_gate * candidate`` and
    ``h_t = output_gate * tanh(c_t)``.
    """
    kept = forget * c_prev          # part of the old memory that survives
    written = input_gate * candidate  # new information admitted to memory
    cell = kept + written
    return cell, output_gate * math.tanh(cell)


def softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis.

    The per-row maximum is subtracted before exponentiating, which leaves the
    result unchanged but prevents overflow for large inputs.
    """
    arr = np.asarray(x, dtype=float)
    shifted = arr - arr.max(axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=-1, keepdims=True)


def self_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Scaled dot-product attention: softmax(q @ k^T / sqrt(d_k)) @ v.

    Args:
        q: queries, shape ``(..., tokens_q, dim)``.
        k: keys, shape ``(..., tokens_k, dim)``.
        v: values, shape ``(..., tokens_k, dim_v)``.
        Leading axes (e.g. batch or heads) are optional and must broadcast.
        Inputs may be any array-like; they are converted to ``float``.

    Returns:
        ``(context, weights)`` with shapes ``(..., tokens_q, dim_v)`` and
        ``(..., tokens_q, tokens_k)``; each row of ``weights`` sums to 1.
    """
    q = np.asarray(q, dtype=float)
    k = np.asarray(k, dtype=float)
    v = np.asarray(v, dtype=float)
    dk = q.shape[-1]
    # Transpose only the last two axes; k.T would reverse every axis and
    # scramble any leading batch/head dimensions.
    scores = q @ np.swapaxes(k, -1, -2) / math.sqrt(dk)
    weights = softmax(scores)
    return weights @ v, weights


def binarize_weights(w: np.ndarray) -> np.ndarray:
    """Map each weight to +1 or -1 by its sign.

    Non-negative entries (including 0) become ``1``, negative entries ``-1``;
    NaN fails the ``>= 0`` test and becomes ``-1``. Accepts any array-like and
    returns an integer array of the same shape.
    """
    is_non_negative = np.asarray(w) >= 0
    return np.where(is_non_negative, 1, -1)


def synthetic_loss_curves() -> dict[str, list[float]]:
    """Return hand-made 10-epoch loss curves that illustrate overfitting.

    Training loss decays as ``exp(-epoch / 3) + 0.05`` (rounded to 4
    decimals). The fixed validation loss bottoms out at epoch 5 and then
    rises again.
    """
    epoch_axis = np.arange(1, 11)
    train_loss = (np.exp(-epoch_axis / 3) + 0.05).round(4)
    val_loss = [0.95, 0.70, 0.55, 0.48, 0.46, 0.50, 0.58, 0.72, 0.90, 1.12]
    curves: dict[str, list[float]] = {}
    curves["epoch"] = epoch_axis.tolist()
    curves["train_loss"] = train_loss.tolist()
    curves["val_loss"] = val_loss
    return curves


def maybe_plot_loss(curves: dict[str, list[float]]) -> str:
    """Save a train/validation loss plot when Matplotlib is usable.

    Matplotlib is optional here. If importing it fails, no plot is written and
    a status message is returned instead, so the numerical demo still runs.

    Args:
        curves: mapping with ``"epoch"``, ``"train_loss"`` and ``"val_loss"``
            sequences of equal length.

    Returns:
        The saved PNG path as a string, or a message explaining why the plot
        was skipped.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    out = OUTPUT_DIR / "bab10_loss_curves.png"
    try:
        # Some learning environments mix a new NumPy with an old Matplotlib.
        # stderr is redirected so the long binary-compatibility message does
        # not bury the output; the core computations still run.
        with contextlib.redirect_stderr(io.StringIO()):
            import matplotlib.pyplot as plt
    except Exception as exc:  # pragma: no cover - optional dependency
        # Exception, not BaseException: Ctrl+C / SystemExit during a slow
        # import must still propagate rather than be reported as "missing".
        return f"matplotlib tidak tersedia/kompatibel ({type(exc).__name__}: {exc}); plot dilewati"
    fig, ax = plt.subplots(figsize=(7, 4.2))
    try:
        ax.plot(curves["epoch"], curves["train_loss"], marker="o", label="train loss")
        ax.plot(curves["epoch"], curves["val_loss"], marker="s", label="validation loss")
        ax.set_title("Bab 10: contoh diagnosis overfitting")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.grid(alpha=0.3)
        ax.legend()
        fig.tight_layout()
        fig.savefig(out, dpi=160)
    finally:
        # Release the figure even if saving fails, so repeated calls in a
        # notebook session do not accumulate open figures.
        plt.close(fig)
    return str(out)


def run_demo() -> dict[str, object]:
    """Run all chapter-10 toy examples and collect their results.

    Side effects: ``maybe_plot_loss`` may write a loss-curve PNG, and the
    returned dict is always written as JSON to
    ``OUTPUT_DIR / "bab10_demo_results.json"``.

    Returns:
        A JSON-serialisable dict with one entry per demo, in a fixed key order.
    """
    # CNN building blocks on a tiny 4x4 single-channel "image".
    image = np.array(
        [
            [1, 2, 0, 1],
            [3, 1, 2, 2],
            [0, 1, 3, 1],
            [2, 2, 1, 0],
        ],
        dtype=float,
    )
    diagonal_kernel = np.array([[1, 0], [0, -1]], dtype=float)
    feature_map = conv2d_valid(image, diagonal_kernel)
    pooled_map = max_pool2d(image)

    # Cost of one 3x3 layer (32 -> 64 channels, 64x64 map): standard vs. separable.
    layer_shape = {"kernel": 3, "in_channels": 32, "out_channels": 64, "feature_size": 64}
    standard_cost = conv_multadds(**layer_shape)
    separable_cost = depthwise_separable_multadds(**layer_shape)

    # One vanilla-RNN step and one scalar LSTM cell update.
    rnn_hidden = rnn_step(
        x_t=np.array([2.0]),
        h_prev=np.array([0.5]),
        wx=np.array([[0.4]]),
        wh=np.array([[0.6]]),
        b=np.array([-0.1]),
    )
    cell_state, hidden_state = lstm_step_scalar(
        c_prev=2.0, forget=0.8, input_gate=0.3, candidate=0.5, output_gate=0.9
    )

    # Two-token scaled dot-product self-attention.
    queries = np.array([[1.0, 0.0], [0.2, 0.8]])
    keys = np.array([[1.0, 0.0], [0.0, 1.0]])
    values = np.array([[10.0, 0.0], [0.0, 5.0]])
    attn_context, attn_weights = self_attention(queries, keys, values)

    # Synthetic train/validation curves; plotting is best-effort.
    loss_curves = synthetic_loss_curves()
    plot_message = maybe_plot_loss(loss_curves)

    summary = {
        "conv_valid": feature_map.tolist(),
        "max_pool": pooled_map.tolist(),
        "conv_output_32_k5_p2_s1": conv_output_size(32, 5, 2, 1),
        "conv_params_3x3_3_to_64": conv_params(3, 3, 64),
        "standard_conv_multadds": standard_cost,
        "depthwise_separable_multadds": separable_cost,
        "separable_saving_ratio": round(standard_cost / separable_cost, 3),
        "residual_example": residual_block(np.array([2.0, -1.0, 3.0])).tolist(),
        "rnn_hidden": rnn_hidden.round(6).tolist(),
        "lstm_cell_hidden": [round(cell_state, 6), round(hidden_state, 6)],
        "attention_weights": attn_weights.round(4).tolist(),
        "attention_context": attn_context.round(4).tolist(),
        "binary_weights": binarize_weights(np.array([-0.7, 0.2, 1.3])).tolist(),
        "loss_curves": loss_curves,
        "plot_status": plot_message,
    }
    results_path = OUTPUT_DIR / "bab10_demo_results.json"
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    results_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    return summary


def self_test() -> None:
    """Run quick sanity checks on the core functions.

    The checks go through an explicit helper instead of bare ``assert``
    statements. ``python -O`` strips asserts, which would make
    ``--self-test`` print its success message without checking anything.

    Raises:
        AssertionError: naming the first check that failed.
    """

    def check(condition: object, label: str) -> None:
        # An explicit raise survives ``python -O``, unlike ``assert``.
        if not condition:
            raise AssertionError(f"self-test failed: {label}")

    x = np.array([[1, 2], [3, 4]])
    k = np.array([[1, 0], [0, -1]])
    conv = conv2d_valid(x, k)
    check(conv.shape == (1, 1), "conv2d_valid shape")
    check(conv[0, 0] == -3, "conv2d_valid value")
    check(conv_output_size(28, 3, 1, 1) == 28, "conv_output_size")
    check(conv_params(3, 1, 8) == 80, "conv_params")
    check(max_pool2d(np.array([[1, 9], [3, 4]]))[0, 0] == 9, "max_pool2d")
    standard = conv_multadds(3, 32, 64, 64)
    sep = depthwise_separable_multadds(3, 32, 64, 64)
    check(standard > sep, "depthwise separable is cheaper than standard conv")
    c_t, h_t = lstm_step_scalar(4, 0.25, 0.5, 2, 1.0)
    check(abs(c_t - 2.0) < 1e-9, "lstm cell state")
    check(-1 <= h_t <= 1, "lstm hidden state range")
    context, weights = self_attention(np.eye(2), np.eye(2), np.eye(2))
    check(context.shape == (2, 2), "self_attention context shape")
    check(weights.shape == (2, 2), "self_attention weights shape")
    check(np.allclose(weights.sum(axis=-1), 1.0), "self_attention weights rows sum to 1")
    check(binarize_weights(np.array([-2, 0, 3])).tolist() == [-1, 1, 1], "binarize_weights")


def main() -> None:
    """Command-line entry point.

    With ``--self-test``, runs the quick checks and prints a success line.
    Otherwise, runs the full demo and prints its results as indented JSON.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--self-test", action="store_true", help="jalankan cek cepat fungsi inti")
    options = cli.parse_args()
    if not options.self_test:
        demo_results = run_demo()
        print("=== Bab 10 Deep Learning Playground ===")
        print(json.dumps(demo_results, indent=2, ensure_ascii=False))
        return
    self_test()
    print("Self-test Bab 10 berhasil")


# Run the CLI only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
