#!/usr/bin/env python3
"""Bab 11 generative AI playground.

Demo ringan tanpa API luar:
- tokenisasi sederhana
- embedding bag-of-words dan cosine retrieval
- self-attention mini
- bigram next-token decoding dengan temperature/top-k
- mini RAG berbasis dokumen lokal
- pemeriksaan klaim sederhana terhadap sumber

Jalankan:
  python3 generative_ai_playground.py --self-test
  python3 generative_ai_playground.py
"""
from __future__ import annotations

import argparse
import json
import math
import random
import re
from collections import Counter, defaultdict
from pathlib import Path
from typing import Iterable

import numpy as np

# Fixed seed shared by the demos so repeated runs produce the same output.
SEED = 42
# Seeding happens at import time (a module-level side effect). In this file the
# sampling in sample_next() draws from ``random`` (random.choices); no np.random
# call is visible here, so the NumPy seed looks purely defensive.
# NOTE(review): confirm nothing else relies on np.random before removing it.
random.seed(SEED)
np.random.seed(SEED)

# Tiny in-memory knowledge base for the retrieval / mini-RAG demos. Each entry
# has an "id" (returned as the cited source), a "title" (quoted in the answer)
# and the "text" that retrieve() scores against the query.
DOCS: list[dict[str, str]] = [
    {
        "id": "S1",
        "title": "SOP Pinjaman Koperasi",
        "text": "Anggota koperasi dapat mengajukan pinjaman setelah aktif minimal enam bulan dan memiliki simpanan pokok lunas. Pengajuan wajib melampirkan KTP, formulir, dan rekomendasi ketua unit.",
    },
    {
        "id": "S2",
        "title": "Panduan UMKM Digital",
        "text": "Pelaku UMKM sebaiknya mencatat stok, harga, dan transaksi harian. Foto produk yang jelas dan deskripsi singkat membantu pelanggan memahami manfaat produk.",
    },
    {
        "id": "S3",
        "title": "Aturan Laboratorium Sekolah",
        "text": "Siswa wajib memakai kacamata pelindung saat praktikum kimia. Makanan dan minuman tidak boleh dibawa ke area praktikum. Guru harus memeriksa alat sebelum kegiatan dimulai.",
    },
    {
        "id": "S4",
        "title": "Kebijakan AI Internal",
        "text": "Data pribadi seperti NIK, rekam medis, nilai siswa, dan rahasia perusahaan tidak boleh dimasukkan ke layanan AI publik tanpa izin tertulis dan proses anonimisasi.",
    },
]

# Training text for the toy bigram model: train_bigram() treats each line as one
# sentence. The literal starts with a newline, so its first line is empty.
MINI_CORPUS = """
AI membantu guru membuat latihan dan ringkasan.
Guru memeriksa jawaban AI sebelum diberikan ke siswa.
UMKM memakai AI untuk deskripsi produk dan balasan pelanggan.
Dokumen resmi harus menjadi sumber jawaban chatbot.
Jika sumber tidak memuat jawaban, sistem harus mengatakan tidak ditemukan.
"""


def tokenize(text: str) -> list[str]:
    """Split *text* into lower-cased word tokens (toy tokenizer for the demo).

    This is not a production LLM (sub-word) tokenizer. A token is a maximal run
    of word characters: letters, digits and the underscore. The word-character
    class is Unicode-aware for ``str`` patterns, so accented letters (e.g. the
    "é" in "kafé") stay inside their word instead of splitting it. For plain
    ASCII input the result is identical to an ``[a-zA-Z0-9_]+`` pattern.
    """
    return re.findall(r"\w+", text.lower())


def build_vocab(texts: Iterable[str]) -> dict[str, int]:
    """Map every distinct token found in *texts* to a column index.

    Indices follow the alphabetical order of the tokens, so the same collection
    of texts always produces the same vocabulary.
    """
    unique_tokens: set[str] = set()
    for text in texts:
        unique_tokens.update(tokenize(text))
    return dict(zip(sorted(unique_tokens), range(len(unique_tokens))))


def bow_vector(text: str, vocab: dict[str, int]) -> np.ndarray:
    """Return the L2-normalised bag-of-words count vector of *text*.

    Tokens that are not in *vocab* are ignored. When none of the tokens are in
    the vocabulary, the all-zero vector is returned as-is (no division by zero).
    """
    counts = np.zeros(len(vocab), dtype=float)
    known_columns = (vocab[tok] for tok in tokenize(text) if tok in vocab)
    for column in known_columns:
        counts[column] += 1.0
    length = np.linalg.norm(counts)
    if not length:
        return counts
    return counts / length


def cosine(a: np.ndarray, b: np.ndarray) -> float:
    """Return the cosine similarity of *a* and *b*; 0.0 if either has zero length."""
    magnitude = float(np.linalg.norm(a) * np.linalg.norm(b))
    if not magnitude:
        return 0.0
    return float(np.dot(a, b) / magnitude)


def retrieve(query: str, docs: list[dict[str, str]], top_k: int = 2) -> list[dict[str, object]]:
    """Rank *docs* by bag-of-words cosine similarity to *query*.

    The vocabulary is built from the query plus every document text. Each
    result is a shallow copy of a document dict with an extra ``"score"`` key
    (the cosine rounded to 4 decimals). Results are ordered by that rounded
    score, highest first (ties keep their input order), and cut to *top_k*.
    """
    vocab = build_vocab([query, *(doc["text"] for doc in docs)])
    query_vec = bow_vector(query, vocab)
    ranked = [
        {**doc, "score": round(cosine(query_vec, bow_vector(doc["text"], vocab)), 4)}
        for doc in docs
    ]
    # list.sort is stable even with reverse=True, so equal scores keep doc order.
    ranked.sort(key=lambda item: item["score"], reverse=True)
    return ranked[:top_k]


def mini_rag_answer(query: str, docs: list[dict[str, str]] | None = None) -> dict[str, object]:
    """Answer *query* extractively from the best-matching document.

    Args:
        query: the user question.
        docs: candidate documents (dicts with ``id``, ``title`` and ``text``).
            ``None`` (the default) means the module-level ``DOCS`` list.

    Returns:
        A dict with the original ``question``, the ``answer`` text, the cited
        ``sources`` (a list of document ids) and the top-2 ``retrieved`` hits.
        When the best (rounded) retrieval score is not positive, the answer
        is the fixed "not found" message and ``sources`` is empty.
    """
    if docs is None:
        # Resolved at call time instead of being bound once at definition time,
        # so no shared mutable list is baked into the signature and the default
        # always follows the current module-level DOCS.
        docs = DOCS
    hits = retrieve(query, docs, top_k=2)
    if not hits or hits[0]["score"] <= 0:
        return {
            "question": query,
            "answer": "Tidak ditemukan di dokumen.",
            "sources": [],
            "retrieved": hits,
        }
    best = hits[0]
    # Extractive answer: quote the whole best document and remind the reader to
    # verify against the original source.
    answer = (
        f"Berdasarkan {best['id']} ({best['title']}), {best['text']} "
        "Periksa kembali dokumen asli sebelum mengambil keputusan penting."
    )
    return {
        "question": query,
        "answer": answer,
        "sources": [best["id"]],
        "retrieved": hits,
    }


def softmax(x: np.ndarray, temperature: float = 1.0) -> np.ndarray:
    """Return the temperature-scaled softmax of *x*, normalised over all elements.

    Args:
        x: scores; any array-like of numbers is accepted and converted to a
            float ndarray.
        temperature: divisor applied to the scores before exponentiation.
            Values below 1 sharpen the distribution; values above 1 flatten it.

    Returns:
        A float ndarray with the same shape as *x* whose elements sum to 1.

    Raises:
        ValueError: if *temperature* is not a positive number. This includes
            NaN, which would otherwise silently produce NaN probabilities.
    """
    # ``not temperature > 0`` (unlike ``temperature <= 0``) also rejects NaN.
    if not temperature > 0:
        raise ValueError("temperature must be positive")
    z = np.asarray(x, dtype=float) / temperature
    # Shifting by the max leaves the result unchanged but prevents exp() overflow.
    z = z - np.max(z)
    exp = np.exp(z)
    return exp / exp.sum()


def attention_demo() -> dict[str, object]:
    """Run scaled dot-product self-attention on three toy 2-d token vectors.

    The imaginary tokens "sari", "rina" and "dia" get hand-picked query (Q),
    key (K) and value (V) rows. Scores are ``Q @ K.T / sqrt(d_k)``. Each score
    row goes through ``softmax`` to give the attention weights, and
    ``weights @ V`` gives the contextual vectors. Numeric outputs are rounded
    to 4 decimals and returned as nested lists.
    """
    tokens = ["sari", "rina", "dia"]
    queries = np.array([[1.0, 0.2], [0.3, 1.0], [0.9, 0.4]])
    keys = np.array([[1.0, 0.1], [0.2, 1.0], [0.7, 0.5]])
    values = np.array([[1.0, 0.0], [0.0, 1.0], [0.6, 0.6]])

    d_k = keys.shape[1]
    scores = (queries @ keys.T) / math.sqrt(d_k)
    weights = np.array([softmax(score_row) for score_row in scores])
    contextual = weights @ values

    def as_rounded_list(arr: np.ndarray) -> list:
        return np.round(arr, 4).tolist()

    return {
        "tokens": tokens,
        "scores": as_rounded_list(scores),
        "attention_weights": as_rounded_list(weights),
        "contextual_vectors": as_rounded_list(contextual),
    }


def train_bigram(corpus: str) -> dict[str, Counter]:
    """Count token-to-next-token transitions in *corpus*, one sentence per line.

    Every line that yields at least one token is wrapped in ``<bos>`` / ``<eos>``
    markers, so the model also learns which tokens start and end a sentence.

    Lines that yield no tokens are skipped. This covers blank lines, such as
    the empty first line of a triple-quoted string that begins with a newline.
    Without the skip, each such line would add a spurious ``<bos> -> <eos>``
    transition, teaching the model that an empty sentence is a valid output.

    Returns:
        A ``defaultdict(Counter)`` mapping each token to a Counter of the tokens
        observed immediately after it.
    """
    model: dict[str, Counter] = defaultdict(Counter)
    for line in corpus.splitlines():
        words = tokenize(line)
        if not words:
            continue
        toks = ["<bos>", *words, "<eos>"]
        for a, b in zip(toks, toks[1:]):
            model[a][b] += 1
    return model


def sample_next(counter: Counter, temperature: float = 1.0, top_k: int | None = None) -> tuple[str, dict[str, float]]:
    """Sample one successor token from the bigram *counter*.

    Candidates are ordered by count, most frequent first. If *top_k* is given,
    only the *top_k* most frequent candidates are kept.

    Sampling probabilities are ``softmax(log(count) / temperature)``, i.e.
    proportional to ``count ** (1 / temperature)``. A low temperature favours
    frequent tokens; a high temperature flattens the distribution. The draw
    uses the module-level ``random`` generator, so results follow its seed.

    Returns:
        ``(chosen_token, {candidate: probability rounded to 4 decimals})``.

    Raises:
        ValueError: if *top_k* is given but smaller than 1, or if there is no
            candidate to sample from (an empty *counter*).
    """
    if top_k is not None and top_k < 1:
        # Slicing with a negative top_k would silently drop the *least* frequent
        # candidates (items[:-1]); 0 would leave nothing to sample from.
        raise ValueError("top_k must be a positive integer or None")
    items = counter.most_common()
    if top_k is not None:
        items = items[:top_k]
    if not items:
        raise ValueError("cannot sample from an empty counter")
    words = [w for w, _ in items]
    counts = np.array([c for _, c in items], dtype=float)
    # The tiny epsilon guards against log(0) for a zero count.
    probs = softmax(np.log(counts + 1e-9), temperature=temperature)
    choice = random.choices(words, weights=probs, k=1)[0]
    return choice, {w: round(float(p), 4) for w, p in zip(words, probs)}


def generate_bigram(prompt: str, max_tokens: int = 14, temperature: float = 0.8, top_k: int = 5) -> dict[str, object]:
    """Continue *prompt* with up to *max_tokens* tokens sampled from a bigram model.

    A fresh bigram model is trained on ``MINI_CORPUS`` on every call. Generation
    starts from the prompt's last token, or ``<bos>`` for an empty prompt. When
    the current token has no recorded successors, the ``<bos>`` distribution is
    used instead. Generation stops early once ``<eos>`` is sampled; ``<eos>``
    itself is not added to the text.

    Returns:
        ``{"prompt", "text", "trace"}``:
        - ``text`` is the prompt tokens plus the generated tokens, joined by spaces.
        - ``trace`` has one entry per step: the current token, the sampled next
          token and the candidate probabilities.
    """
    model = train_bigram(MINI_CORPUS)

    def successors_of(token: str) -> Counter:
        # Tokens never seen in the corpus fall back to sentence-start statistics.
        return model.get(token) or model["<bos>"]

    prompt_tokens = tokenize(prompt)
    generated = list(prompt_tokens)
    current = prompt_tokens[-1] if prompt_tokens else "<bos>"
    trace = []
    for _ in range(max_tokens):
        token, candidates = sample_next(successors_of(current), temperature=temperature, top_k=top_k)
        trace.append({"current": current, "next": token, "candidates": candidates})
        if token == "<eos>":
            break
        generated.append(token)
        current = token
    return {"prompt": prompt, "text": " ".join(generated), "trace": trace}


def unsupported_claims(answer: str, source_texts: list[str], *, min_support: float = 0.5) -> list[str]:
    """Return the sentences of *answer* that the source texts do not seem to support.

    This is a simple token-overlap heuristic, not a production fact-checker.

    Each sentence ("claim") is reduced to its content tokens: tokens longer than
    3 characters that are not in a small stopword list. Common words such as
    "wajib" or "memiliki" are stopwords so that they cannot make a fabricated
    claim look supported.

    A claim is reported when fewer than *min_support* of its content tokens
    occur anywhere in *source_texts*. Claims with no content tokens are never
    reported.

    Args:
        answer: generated text to check.
        source_texts: the documents the answer should be grounded in.
        min_support: keyword-only; minimum fraction of a claim's content tokens
            that must appear in the sources (default 0.5, i.e. half).

    Returns:
        The stripped unsupported claims, in the order they appear in *answer*.
    """
    stop = {
        "yang", "dan", "atau", "untuk", "dengan", "dari", "pada", "harus",
        "wajib", "dapat", "memiliki", "minimal", "seperti", "adalah",
    }

    def content_tokens(text: str) -> list[str]:
        # Stopwords and tokens of 3 characters or fewer (incl. short numbers) are ignored.
        return [t for t in tokenize(text) if t not in stop and len(t) > 3]

    source_tokens = set(content_tokens(" ".join(source_texts)))
    # Split at ., ! or ? only when it is followed by whitespace or the end of the
    # text. This keeps dotted numbers like "Rp 1.500.000" or "2.1" inside one
    # claim instead of tearing them into separate fragment "claims".
    claims = [c.strip() for c in re.split(r"[.!?]+(?=\s|$)", answer) if c.strip()]
    unsupported = []
    for claim in claims:
        toks = content_tokens(claim)
        if toks and sum(t in source_tokens for t in toks) / len(toks) < min_support:
            unsupported.append(claim)
    return unsupported


def run_demo() -> dict[str, object]:
    """Run every demo once and collect the results in a single dict.

    Keys appear in a fixed order:
    - seed, tokenization, retrieval (all four docs ranked), mini_rag, attention;
    - two bigram generations from "guru", first at low and then at higher
      temperature. Both draw from the seeded global ``random`` generator, so
      their order matters.
    - the unsupported-claim check on a deliberately fabricated answer.
    """
    question = "Apa syarat anggota mengajukan pinjaman koperasi?"
    fabricated_answer = "Anggota wajib memiliki saldo minimal sepuluh juta dan surat dokter."
    document_texts = [doc["text"] for doc in DOCS]

    results: dict[str, object] = {"seed": SEED}
    results["tokenization"] = tokenize("Beras premium 5 kg untuk UMKM Bandung.")
    results["retrieval"] = retrieve(question, DOCS, top_k=4)
    results["mini_rag"] = mini_rag_answer(question)
    results["attention"] = attention_demo()
    results["bigram_low_temperature"] = generate_bigram("guru", temperature=0.4, top_k=3)
    results["bigram_higher_temperature"] = generate_bigram("guru", temperature=1.4, top_k=5)
    results["unsupported_claim_check"] = {
        "answer": fabricated_answer,
        "unsupported_claims": unsupported_claims(fabricated_answer, document_texts),
    }
    return results


def _expect(condition: object, message: str) -> None:
    """Raise AssertionError with *message* when *condition* is falsy.

    Unlike a bare ``assert`` statement, this check is not removed when Python
    runs with ``-O``. That way ``--self-test`` can never report success without
    having actually checked anything.
    """
    if not condition:
        raise AssertionError(message)


def self_test() -> None:
    """Run quick sanity checks over every demo component.

    Raises:
        AssertionError: describing the first check that fails.
    """
    _expect(tokenize("KTP dan NIK!") == ["ktp", "dan", "nik"], "tokenize should lower-case and drop punctuation")
    _expect(
        abs(cosine(np.array([1, 0]), np.array([0, 1]))) < 1e-9,
        "cosine of orthogonal vectors should be 0",
    )
    _expect(
        retrieve("pinjaman koperasi", DOCS, 1)[0]["id"] == "S1",
        "a loan query should retrieve the cooperative loan SOP (S1)",
    )
    for i, row in enumerate(attention_demo()["attention_weights"]):
        _expect(abs(sum(row) - 1.0) < 1e-3, f"attention row {i} should sum to 1, got {sum(row)}")
    rag = mini_rag_answer("Bolehkah memasukkan NIK ke AI publik?")
    _expect(rag["sources"] == ["S4"], f"NIK question should cite S4, got {rag['sources']}")
    bad = "Anggota wajib memiliki saldo minimal sepuluh juta dan surat dokter."
    _expect(
        unsupported_claims(bad, [d["text"] for d in DOCS]),
        "the fabricated claim should be flagged as unsupported",
    )
    gen = generate_bigram("guru", max_tokens=5)
    _expect("text" in gen and gen["trace"], "bigram generation should return text and a non-empty trace")


def main(argv: list[str] | None = None) -> int:
    """Command-line entry point; returns the process exit status.

    Args:
        argv: argument list to parse instead of ``sys.argv[1:]``. ``None`` (the
            default) keeps the normal command-line behaviour; passing a list
            lets tests or notebooks drive the script programmatically.

    With ``--self-test`` only the sanity checks run. Otherwise:
    - every demo runs;
    - the results are written as UTF-8 JSON to ``--output``, creating parent
      directories as needed;
    - the same JSON is echoed to stdout.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--self-test", action="store_true", help="run lightweight assertions")
    parser.add_argument(
        "--output",
        default="outputs/bab11_demo_results.json",
        help="path of the JSON results file",
    )
    args = parser.parse_args(argv)

    if args.self_test:
        self_test()
        print("self-test passed")
        return 0

    results = run_demo()
    # Serialise once and reuse the text for both the file and the console.
    payload = json.dumps(results, indent=2, ensure_ascii=False)
    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(payload, encoding="utf-8")
    print(payload)
    print(f"\nOutput saved to {out_path.resolve()}")
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
