#!/usr/bin/env python3
"""Bab 03B — Data exploration, visualization, and data quality lab.

Uses only the Python standard library, so it runs as-is in a local terminal,
VS Code, Jupyter, Google Colab, and Kaggle. The chapter also explains
matplotlib/seaborn; this script instead draws its plots as hand-written SVG,
a portable fallback with no external dependencies.
"""

from __future__ import annotations

import csv
import html
import json
import math
import random
from collections import Counter, defaultdict
from dataclasses import dataclass, fields, replace
from pathlib import Path
from statistics import median
from typing import Iterable, Optional, Sequence


@dataclass(frozen=True)
class RawRow:
    """One customer record, used both for the raw sample and for cleaned rows.

    The class is frozen (immutable and hashable), which is what lets ``audit``
    and ``clean`` detect exact duplicate rows via ``set`` membership. The field
    order also defines the column order of the CSV written by ``write_csv``.
    """

    customer_id: str  # entity key used by split_ids; listed as "not a model feature"
    segment: str  # segment label, e.g. "rutin", "baru", "vip", "promo" in RAW_DATA
    payment: str  # raw payment label; spelling/case is normalised by clean()
    device: str  # device label, e.g. "mobile", "desktop", "tablet" in RAW_DATA
    visits: Optional[float]  # None = missing; negative values are treated as invalid
    spend: Optional[float]  # None = missing; clean() fills it with a median
    returned: Optional[int]  # 0/1 in RAW_DATA; exact meaning not shown here (TODO confirm)
    date: str  # "YYYY-MM-DD" strings in RAW_DATA; write_line sorts them as strings


# Small hand-made sample for the lab. Fields follow RawRow:
# (customer_id, segment, payment, device, visits, spend, returned, date).
# It contains the data-quality problems that audit() reports and clean() fixes:
# an exact duplicate row (C001), a missing spend (C006), a negative visits
# count (C007), and inconsistent payment spelling ("Cash" vs "cash").
RAW_DATA = [
    RawRow("C001", "rutin", "QRIS", "mobile", 12, 58, 0, "2026-01-01"),
    RawRow("C002", "rutin", "QRIS", "mobile", 10, 50, 0, "2026-01-02"),
    RawRow("C003", "baru", "Cash", "mobile", 2, 15, 1, "2026-01-03"),
    RawRow("C004", "baru", "cash", "desktop", 3, 17, 0, "2026-01-04"),  # lower-case payment spelling
    RawRow("C005", "vip", "Kartu", "desktop", 1, 170, 0, "2026-01-05"),
    RawRow("C006", "rutin", "QRIS", "mobile", 9, None, 0, "2026-01-06"),  # missing spend
    RawRow("C007", "promo", "QRIS", "tablet", -1, 22, 1, "2026-01-07"),  # invalid: negative visits
    RawRow("C008", "promo", "Cash", "mobile", 4, 25, 0, "2026-01-08"),
    RawRow("C009", "rutin", "QRIS", "mobile", 11, 55, 0, "2026-01-09"),
    RawRow("C001", "rutin", "QRIS", "mobile", 12, 58, 0, "2026-01-01"),  # duplicate
    RawRow("C010", "baru", "Kartu", "desktop", 2, 19, 1, "2026-01-10"),
    RawRow("C011", "vip", "QRIS", "mobile", 5, 120, 0, "2026-01-11"),
]


def mean(values: Iterable[float]) -> float:
    """Return the arithmetic mean of *values*.

    *values* may be any iterable; it is materialised into a list once.
    An empty input raises ``ZeroDivisionError``.
    """
    items = list(values)
    total = sum(items)
    count = len(items)
    return total / count


def std(values: Iterable[float]) -> float:
    """Return the population standard deviation of *values* (divides by N, not N - 1).

    *values* may be any iterable; it is materialised once. An empty input
    raises ``ZeroDivisionError`` (from ``mean``).
    """
    items = list(values)
    centre = mean(items)
    total_sq = sum((item - centre) ** 2 for item in items)
    return math.sqrt(total_sq / len(items))


def quantile(values: Sequence[float], q: float) -> float:
    """Return the *q*-quantile of *values* using linear interpolation.

    The quantile sits at position ``(n - 1) * q`` of the sorted data and is
    interpolated between the two neighbouring order statistics (the common
    "linear" definition). The caller's sequence is not modified.

    Raises:
        ValueError: if *values* is empty or *q* is not within ``[0, 1]``.
            Without this check a negative *q* would silently index from the
            end of the sorted list and return a wrong value.
    """
    if not 0.0 <= q <= 1.0:
        raise ValueError(f"q must be within [0, 1], got {q!r}")
    ordered = sorted(values)
    if not ordered:
        raise ValueError("quantile() requires at least one value")
    pos = (len(ordered) - 1) * q
    lo = math.floor(pos)
    hi = math.ceil(pos)
    if lo == hi:
        return ordered[lo]
    # hi == lo + 1 here; weight each neighbour by its distance to the other one.
    return ordered[lo] * (hi - pos) + ordered[hi] * (pos - lo)


def audit(rows: Sequence[RawRow]) -> dict:
    """Profile raw rows for basic data-quality problems without modifying them.

    Returns a dict with:
      - ``rows`` / ``columns``: table size (columns = number of dataclass fields);
      - ``missing_cells`` / ``missing_rate``: cells that are ``None`` or ``""``;
      - ``duplicate_rows``: rows identical to an earlier row;
      - ``invalid_visits``: rows whose ``visits`` is negative;
      - ``payment_raw_values``: frequency of each raw ``payment`` spelling.

    The column count comes from the dataclass fields rather than a hard-coded
    number, so adding a field to ``RawRow`` keeps ``missing_rate`` correct.
    An empty input yields zero counts and a ``missing_rate`` of 0.0.
    """
    n_columns = len(fields(rows[0] if rows else RawRow))
    total_cells = len(rows) * n_columns
    missing = sum(
        value is None or value == ""
        for row in rows
        for value in vars(row).values()
    )
    # Rows are frozen dataclasses and therefore hashable: identical rows collapse in a set.
    duplicate_rows = len(rows) - len(set(rows))
    invalid_visits = sum(row.visits is not None and row.visits < 0 for row in rows)
    payment_values = Counter(row.payment for row in rows)
    return {
        "rows": len(rows),
        "columns": n_columns,
        "missing_cells": missing,
        "missing_rate": missing / total_cells if total_cells else 0.0,
        "duplicate_rows": duplicate_rows,
        "invalid_visits": invalid_visits,
        "payment_raw_values": dict(payment_values),
    }


# Canonical spelling for each known payment label, keyed by the stripped,
# upper-cased raw value. Labels not listed here are kept upper-cased.
_PAYMENT_ALIASES = {"CASH": "Cash", "QRIS": "QRIS", "KARTU": "Kartu", "CARD": "Kartu"}


def _normalize_payment(raw: str) -> str:
    """Return the canonical spelling of a raw payment label (case/whitespace-insensitive)."""
    key = raw.strip().upper()
    return _PAYMENT_ALIASES.get(key, key)


def clean(rows: Sequence[RawRow]) -> tuple[list[RawRow], list[str]]:
    """Deduplicate, validate, normalise and impute the raw rows.

    Steps, in order:
      1. drop exact duplicate rows (the first occurrence is kept);
      2. if any remaining row lacks ``spend``, compute the median of the known
         spends of the deduplicated rows as the fill value;
      3. drop rows whose ``visits`` is missing or negative;
      4. strip/lower-case ``segment`` and ``device``, canonicalise ``payment``
         and fill a missing ``spend`` with the median from step 2.

    Returns:
        ``(cleaned_rows, log)`` where *log* holds one (Indonesian) message per
        cleaning action, in the order the actions happened. An empty input
        returns ``([], [])``.

    Raises:
        ValueError: if some row lacks ``spend`` but no row has a known spend,
            so there is no median to impute with.

    NOTE(review): the median in step 2 also includes rows that step 3 later
    drops, so a discarded row's spend still influences the fill value —
    confirm this is intended.
    """
    log: list[str] = []

    # Step 1: rows are frozen dataclasses (hashable), so a set finds exact duplicates.
    seen: set = set()
    deduped: list[RawRow] = []
    for row in rows:
        if row in seen:
            log.append(f"hapus duplikat identik: {row.customer_id}")
            continue
        seen.add(row)
        deduped.append(row)

    # Step 2: the median is only needed (and only logged) when a spend is missing;
    # computing it unconditionally made empty or fully-known inputs crash/mislog.
    spend_fill: Optional[float] = None
    if any(row.spend is None for row in deduped):
        known_spend = [row.spend for row in deduped if row.spend is not None]
        if not known_spend:
            raise ValueError("cannot impute missing spend: no row has a known spend value")
        spend_fill = float(median(known_spend))
        log.append(f"imputasi spend kosong dengan median={spend_fill:.2f}")

    # Steps 3-4: drop invalid visits, then normalise text fields and impute spend.
    cleaned: list[RawRow] = []
    for row in deduped:
        if row.visits is None or row.visits < 0:
            log.append(f"buang {row.customer_id}: visits invalid ({row.visits})")
            continue
        cleaned.append(
            replace(
                row,
                segment=row.segment.strip().lower(),
                payment=_normalize_payment(row.payment),
                device=row.device.strip().lower(),
                spend=row.spend if row.spend is not None else spend_fill,
            )
        )
    return cleaned, log


def z_scores(values: Sequence[float]) -> list[float]:
    """Standardise *values* as ``(x - mean) / std`` using the population std.

    When the spread is zero (a constant series) 1.0 is used as the divisor, so
    the scores come out as (essentially) 0 instead of raising ZeroDivisionError.
    """
    centre = mean(values)
    spread = std(values)
    if not spread:
        spread = 1.0
    return [(value - centre) / spread for value in values]


def linear_regression(xs: Sequence[float], ys: Sequence[float]) -> tuple[float, float]:
    """Fit ``y = w * x + b`` to paired samples by ordinary least squares.

    Returns:
        ``(w, b)`` — the slope and the intercept.

    If the squared deviations of *xs* sum to exactly zero (every x equals the
    mean), the denominator is replaced by 1.0. The slope then comes out as 0
    and the intercept as ``mean(ys)``, instead of dividing by zero.

    Raises:
        ValueError: if *xs* and *ys* differ in length, or are empty. A length
            mismatch used to be silently truncated by ``zip`` while the means
            still used every element, giving wrong coefficients.
    """
    if len(xs) != len(ys):
        raise ValueError(f"xs and ys must have the same length ({len(xs)} != {len(ys)})")
    if len(xs) == 0:
        raise ValueError("linear_regression() requires at least one (x, y) pair")
    xbar = mean(xs)
    ybar = mean(ys)
    numerator = sum((x - xbar) * (y - ybar) for x, y in zip(xs, ys))
    denominator = sum((x - xbar) ** 2 for x in xs) or 1.0
    w = numerator / denominator
    b = ybar - w * xbar
    return w, b


def split_ids(rows: Sequence[RawRow], seed: int = 42) -> dict:
    """Build a reproducible, roughly 70/15/15 train/validation/test split of customer IDs.

    The split is entity-level. Each ``customer_id`` is listed once, in
    first-seen order, before shuffling, so a customer with several rows can
    never land in more than one split. Split boundaries are truncated with
    ``int()``. Shuffling uses a private ``random.Random(seed)``, so the global
    RNG state is untouched.

    Returns:
        A JSON-serialisable manifest with the seed, the strategy description,
        the three ID lists and the leakage checks that were considered.
    """
    # dict.fromkeys removes repeated IDs while keeping first-occurrence order,
    # so the shuffle (and the split) is unchanged when every ID is unique.
    ids = list(dict.fromkeys(row.customer_id for row in rows))
    random.Random(seed).shuffle(ids)
    n = len(ids)
    train_end = int(n * 0.7)
    valid_end = int(n * 0.85)
    return {
        "seed": seed,
        "strategy": "entity-level random split after cleaning; no duplicate customer rows",
        "train": ids[:train_end],
        "validation": ids[train_end:valid_end],
        "test": ids[valid_end:],
        "leakage_checks": [
            "cleaning median is computed on cleaned dataset only for teaching; in production fit on train only",
            "customer_id is not a model feature",
            "future outcome columns are not included in features",
        ],
    }


def write_csv(rows: Sequence[RawRow], path: Path) -> None:
    """Write *rows* to *path* as a UTF-8 CSV file with a header line.

    Columns follow the dataclass field order of the rows, and ``None`` is
    written as an empty cell. Parent directories are created if needed.
    With no rows, a header-only file based on ``RawRow`` is written instead
    of failing on ``rows[0]``.
    """
    header = [spec.name for spec in fields(rows[0] if rows else RawRow)]
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=header, lineterminator="\n")
        writer.writeheader()
        writer.writerows(vars(row) for row in rows)


def scale(v: float, lo: float, hi: float, out_lo: float, out_hi: float) -> float:
    """Map *v* linearly from the input range ``[lo, hi]`` onto ``[out_lo, out_hi]``.

    The output range may be reversed (``out_lo > out_hi``); the plots use this
    to flip the SVG y-axis. A degenerate input range (``lo == hi``) maps every
    value to the midpoint of the output range. Values are not clamped.
    """
    if lo != hi:
        stretched = (v - lo) * (out_hi - out_lo)
        return out_lo + stretched / (hi - lo)
    return (out_lo + out_hi) / 2


def svg_shell(title: str, subtitle: str, body: str) -> str:
    """Wrap SVG *body* markup in a 760x460 card with a title and a subtitle.

    *title* and *subtitle* are plain text. They are XML-escaped here (``&``,
    ``<``, ``>``, quotes), so a title such as ``"Spend & Visits"`` still yields
    a well-formed SVG; *title* also becomes the ``aria-label`` attribute.
    *body* must already be SVG markup and is inserted verbatim.
    """
    safe_title = html.escape(str(title))
    safe_subtitle = html.escape(str(subtitle))
    return f'''<svg xmlns="http://www.w3.org/2000/svg" width="760" height="460" viewBox="0 0 760 460" role="img" aria-label="{safe_title}">
<rect width="760" height="460" fill="#f8fafc"/>
<rect x="24" y="22" width="712" height="416" rx="22" fill="#ffffff" stroke="#cbd5e1" stroke-width="2"/>
<text x="50" y="60" font-family="Arial" font-size="24" font-weight="700" fill="#0f172a">{safe_title}</text>
<text x="50" y="86" font-family="Arial" font-size="15" fill="#475569">{safe_subtitle}</text>
{body}
</svg>'''


def write_bar(counter: Counter, path: Path, title: str) -> None:
    """Render *counter* as a vertical bar-chart SVG at *path*.

    Bars are ordered by descending count (``Counter.most_common``) and scaled
    so the tallest one is 220 px high. Bars are placed 120 px apart, so the
    layout suits a handful of categories. Category labels are XML-escaped.
    An empty counter produces just the axes instead of failing in ``max()``.
    """
    max_count = max(counter.values(), default=0) or 1
    body = ['<path d="M90 380 L690 380 M90 380 L90 120" stroke="#334155"/>']
    for i, (label, count) in enumerate(counter.most_common()):
        h = scale(count, 0, max_count, 0, 220)
        x = 120 + i * 120
        text = html.escape(str(label))
        body.append(f'<rect x="{x}" y="{380-h:.1f}" width="70" height="{h:.1f}" fill="#2563eb"/><text x="{x}" y="405" font-size="13">{text}</text><text x="{x+22}" y="{370-h:.1f}" font-size="13">{count}</text>')
    path.write_text(svg_shell(title, "bar plot cocok untuk membandingkan kategori", "\n".join(body)), encoding="utf-8")


def write_hist(values: Sequence[float], path: Path) -> None:
    """Render a 5-bin histogram of *values* as an SVG at *path*.

    Bins have equal width between the minimum and maximum, and the maximum
    value is folded into the last bin. If all values are equal the bin width
    falls back to 1, so every value lands in the first bin. An empty *values*
    raises ``ValueError`` from ``min``.
    """
    n_bins = 5
    smallest = min(values)
    largest = max(values)
    bin_width = (largest - smallest) / n_bins or 1
    counts = [0 for _ in range(n_bins)]
    for value in values:
        index = int((value - smallest) / bin_width)
        counts[min(n_bins - 1, index)] += 1
    tallest = max(counts) or 1
    parts = ['<path d="M90 380 L690 380 M90 380 L90 120" stroke="#334155"/>']
    for position, count in enumerate(counts):
        height = scale(count, 0, tallest, 0, 220)
        left = 120 + position * 90
        parts.append(f'<rect x="{left}" y="{380-height:.1f}" width="64" height="{height:.1f}" fill="#60a5fa"/><text x="{left+20}" y="{370-height:.1f}" font-size="13">{count}</text>')
    svg = svg_shell("Histogram Belanja", "histogram cocok untuk distribusi numerik", "\n".join(parts))
    path.write_text(svg, encoding="utf-8")


def write_scatter(rows: Sequence[RawRow], path: Path) -> None:
    """Render a visits-vs-spend scatter-plot SVG at *path*, one labelled dot per row.

    Only rows with both ``visits`` and ``spend`` present are drawn, and both
    axis ranges come from exactly those points. Drawing a missing value as 0
    could place it outside the range the axes were scaled to. The axis limits
    are computed once rather than per point. Customer IDs are XML-escaped.
    """
    points = [row for row in rows if row.visits is not None and row.spend is not None]
    body = ['<path d="M90 380 L690 380 M90 380 L90 120" stroke="#334155"/>']
    if points:
        visit_lo = min(row.visits for row in points)
        visit_hi = max(row.visits for row in points)
        spend_lo = min(row.spend for row in points)
        spend_hi = max(row.spend for row in points)
        for row in points:
            x = scale(row.visits, visit_lo, visit_hi, 105, 670)
            y = scale(row.spend, spend_lo, spend_hi, 370, 130)
            label = html.escape(str(row.customer_id))
            body.append(f'<circle cx="{x:.1f}" cy="{y:.1f}" r="8" fill="#16a34a"/><text x="{x+8:.1f}" y="{y-7:.1f}" font-size="11">{label}</text>')
    path.write_text(svg_shell("Scatter Visits vs Spend", "scatter cocok untuk dua variabel numerik", "\n".join(body)), encoding="utf-8")


def write_box(rows: Sequence[RawRow], path: Path) -> None:
    """Render a box plot of spend per segment as an SVG at *path*.

    There is one box per segment, in alphabetical order, spanning Q1..Q3 with
    a red median line; whiskers are not drawn. Missing spend counts as 0.
    All boxes share one y-scale, from the overall minimum to the overall
    maximum spend, computed once up front. Segment labels are XML-escaped.
    """
    spends = [row.spend or 0 for row in rows]
    groups: dict[str, list[float]] = defaultdict(list)
    for row, spend in zip(rows, spends):
        groups[row.segment].append(spend)
    y_lo, y_hi = min(spends, default=0), max(spends, default=0)
    body = ['<path d="M90 380 L690 380 M90 380 L90 120" stroke="#334155"/>']
    for i, (seg, vals) in enumerate(sorted(groups.items())):
        # quantile() sorts its input itself, so no pre-sorting is needed here.
        q1, q2, q3 = (quantile(vals, q) for q in (0.25, 0.5, 0.75))
        y1, y2, y3 = (scale(v, y_lo, y_hi, 370, 130) for v in (q1, q2, q3))
        x = 150 + i * 120
        label = html.escape(str(seg))
        body.append(f'<rect x="{x}" y="{y3:.1f}" width="60" height="{y1-y3:.1f}" fill="#dbeafe" stroke="#2563eb"/><path d="M{x} {y2:.1f} L{x+60} {y2:.1f}" stroke="#ef4444" stroke-width="3"/><text x="{x}" y="405" font-size="12">{label}</text>')
    path.write_text(svg_shell("Box Plot Spend by Segment", "box plot menampilkan median dan kuartil", "\n".join(body)), encoding="utf-8")


def write_line(rows: Sequence[RawRow], path: Path) -> None:
    """Render total spend per date as a line-plot SVG at *path*.

    Rows are grouped by ``date`` and their spend is summed (missing spend
    counts as 0). Several rows on the same day therefore become one point,
    not a zig-zag of same-day points. Dates are ordered by string sort, which
    is chronological for "YYYY-MM-DD" strings. Points are evenly spaced and
    do not reflect gaps between dates. The line is drawn first, so the
    markers stay visible on top of it.
    """
    daily: dict[str, float] = defaultdict(float)
    for row in rows:
        daily[row.date] += row.spend or 0
    dates = sorted(daily)
    totals = [daily[day] for day in dates]
    y_lo, y_hi = min(totals, default=0), max(totals, default=0)
    last = len(dates) - 1
    pts = []
    markers = []
    for i, total in enumerate(totals):
        x = scale(i, 0, last, 110, 670)
        y = scale(total, y_lo, y_hi, 370, 130)
        pts.append(f'{x:.1f},{y:.1f}')
        markers.append(f'<circle cx="{x:.1f}" cy="{y:.1f}" r="5" fill="#7c3aed"/>')
    body = [
        '<path d="M90 380 L690 380 M90 380 L90 120" stroke="#334155"/>',
        f'<polyline points="{" ".join(pts)}" fill="none" stroke="#7c3aed" stroke-width="4"/>',
        *markers,
    ]
    path.write_text(svg_shell("Line Plot Harian", "line plot cocok untuk urutan waktu", "\n".join(body)), encoding="utf-8")


def write_outlier(rows: Sequence[RawRow], path: Path, threshold: float = 1.5) -> None:
    """Plot each row's spend z-score as an SVG at *path*, highlighting outliers.

    Rows are drawn in input order, evenly spaced. Points with
    ``|z| > threshold`` (default 1.5) are red and the rest are green. Missing
    spend counts as 0. The y-range is computed once from all z-scores.
    Customer IDs are XML-escaped. An empty *rows* yields just the axes
    instead of failing inside ``z_scores``.
    """
    spends = [r.spend or 0 for r in rows]
    zs = z_scores(spends) if spends else []
    z_lo, z_hi = min(zs, default=0), max(zs, default=0)
    last = len(rows) - 1
    body = ['<path d="M90 380 L690 380 M90 380 L90 120" stroke="#334155"/>']
    for i, (row, z) in enumerate(zip(rows, zs)):
        x = scale(i, 0, last, 110, 670)
        y = scale(z, z_lo, z_hi, 370, 130)
        color = '#ef4444' if abs(z) > threshold else '#10b981'
        label = html.escape(str(row.customer_id))
        body.append(f'<circle cx="{x:.1f}" cy="{y:.1f}" r="8" fill="{color}"/><text x="{x-10:.1f}" y="405" font-size="10">{label}</text><text x="{x-12:.1f}" y="{y-10:.1f}" font-size="11">{z:.1f}</text>')
    path.write_text(svg_shell("Z-score Outlier", "visualisasi kandidat outlier belanja", "\n".join(body)), encoding="utf-8")


def write_pie(counter: Counter, path: Path) -> None:
    """Render a simplified "pie-style" share chart of *counter* as an SVG at *path*.

    This is deliberately simple for portability. It draws one decorative
    circle plus a legend with a colour swatch and percentage per category, in
    descending count order; no real pie slices are drawn. Colours cycle after
    four categories. Labels are XML-escaped. A zero total shows 0.0% instead
    of dividing by zero.
    """
    total = sum(counter.values())
    # Keep this intentionally simple: labels + color blocks instead of true arcs for portability.
    colors = ['#60a5fa', '#34d399', '#fbbf24', '#f87171']
    body = ['<circle cx="250" cy="250" r="95" fill="#dbeafe"/>']
    for i, (label, count) in enumerate(counter.most_common()):
        pct = count / total * 100 if total else 0.0
        text = html.escape(str(label))
        body.append(f'<rect x="430" y="{160+i*38}" width="24" height="24" fill="{colors[i%len(colors)]}"/><text x="465" y="{178+i*38}" font-size="14">{text}: {pct:.1f}%</text>')
    path.write_text(svg_shell("Pie-style Device Share", "pie cocok untuk sedikit kategori proporsi", "\n".join(body)), encoding="utf-8")


def _render_report(
    audit_report: dict,
    cleaning_log: Sequence[str],
    spends: Sequence[float],
    slope: float,
    intercept: float,
    residuals: Sequence[float],
    top_payment: str,
    top_device: str,
) -> str:
    """Assemble the Markdown data-audit report from precomputed results.

    The IQR upper fence is ``Q3 + 1.5 * IQR``. The returned text ends with a
    single trailing newline.
    """
    q1 = quantile(spends, 0.25)
    q3 = quantile(spends, 0.75)
    iqr = q3 - q1
    upper_fence = q3 + 1.5 * iqr
    # The log is joined into one element, so an empty log still leaves an
    # (empty) line between the surrounding blank lines.
    log_block = "\n".join(f"- {item}" for item in cleaning_log)
    lines = [
        "# Data Audit Report Bab 03B",
        "",
        "## Audit mentah",
        "",
        "```json",
        json.dumps(audit_report, ensure_ascii=False, indent=2),
        "```",
        "",
        "## Cleaning log",
        "",
        log_block,
        "",
        "## Statistik belanja",
        "",
        f"- mean: {mean(spends):.2f}",
        f"- median: {median(spends):.2f}",
        f"- Q1: {q1:.2f}",
        f"- Q3: {q3:.2f}",
        f"- IQR: {iqr:.2f}",
        f"- batas atas outlier IQR: {upper_fence:.2f}",
        "",
        "## Regresi linear insight",
        "",
        f"- spend_hat = {slope:.3f} * visits + {intercept:.3f}",
        f"- residual terbesar: {max(residuals, key=abs):.2f}",
        "",
        "## Insight awal",
        "",
        f"- Payment paling sering: {top_payment}",
        f"- Device paling sering: {top_device}",
        "- Nilai belanja di atas batas IQR perlu diperiksa, bukan otomatis dihapus.",
        "- Split manifest disimpan untuk mencegah evaluasi tidak reproducible.",
    ]
    return "\n".join(lines) + "\n"


def main() -> None:
    """Run the whole lab and write its artefacts into an ``outputs`` directory.

    The directory sits next to this script. When ``__file__`` is undefined
    (e.g. inside a notebook), the current working directory is used instead.
    Writes the cleaned CSV, the split manifest, seven SVG charts and a
    Markdown audit report, then prints a short summary to stdout.
    """
    if "__file__" in globals():
        base_dir = Path(__file__).resolve().parent
    else:
        base_dir = Path.cwd()
    out_dir = base_dir / "outputs"
    out_dir.mkdir(parents=True, exist_ok=True)

    raw_audit = audit(RAW_DATA)
    cleaned, cleaning_log = clean(RAW_DATA)
    spends = [row.spend or 0 for row in cleaned]
    visits = [row.visits or 0 for row in cleaned]
    slope, intercept = linear_regression(visits, spends)
    residuals = [y - (slope * x + intercept) for x, y in zip(visits, spends)]
    payment_counts = Counter(r.payment for r in cleaned)
    device_counts = Counter(r.device for r in cleaned)

    write_csv(cleaned, out_dir / "cleaned_customers.csv")
    manifest = json.dumps(split_ids(cleaned), ensure_ascii=False, indent=2)
    (out_dir / "split_manifest.json").write_text(manifest + "\n", encoding="utf-8")

    write_bar(payment_counts, out_dir / "bar_payment.svg", "Bar Plot Payment")
    write_pie(device_counts, out_dir / "pie_device.svg")
    write_hist(spends, out_dir / "hist_spend.svg")
    write_scatter(cleaned, out_dir / "scatter_visit_spend.svg")
    write_box(cleaned, out_dir / "box_spend_by_segment.svg")
    write_line(cleaned, out_dir / "line_daily_sales.svg")
    write_outlier(cleaned, out_dir / "outlier_zscore.svg")

    report = _render_report(
        raw_audit,
        cleaning_log,
        spends,
        slope,
        intercept,
        residuals,
        payment_counts.most_common(1)[0][0],
        device_counts.most_common(1)[0][0],
    )
    (out_dir / "data_audit_report.md").write_text(report, encoding="utf-8")

    print("Audit:", raw_audit)
    print("Cleaning log:", cleaning_log)
    print("Linear regression: spend_hat =", round(slope, 3), "* visits +", round(intercept, 3))
    print("Outputs written to", out_dir)


# Run the lab only when executed as a script; importing this module (for
# example from a notebook or a test) only defines the functions and RAW_DATA.
if __name__ == "__main__":
    main()
