#!/usr/bin/env python3
"""render_figures.py - render the 5 paper figures from results.jsonl.zst and the atlas JSONL files.

Standalone CLI; same logic as notebooks/figures.ipynb but produces all 10
(5 figures x {pdf,png}) outputs in one shot. Run from the repo root:

    python3 render_figures.py [--out figures/] [--results results/results.jsonl.zst]

Required: numpy, matplotlib, zstandard.
"""
from __future__ import annotations

import argparse
import io
import json
import math
import pathlib
from collections import defaultdict

import numpy as np
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import zstandard


def _iter_json_lines(lines):
    """Yield ``json.loads(line)`` for every non-blank line in *lines*."""
    for line in lines:
        if line.strip():
            yield json.loads(line)


def load_jsonl(path: pathlib.Path):
    """Lazily yield one parsed JSON object per non-blank line of *path*.

    Files ending in ``.zst`` are zstd-decompressed as a stream (all frames are
    read and the file is never held in memory in full); any other file is read
    as UTF-8 text.  Both branches split records only on newline / CRLF / CR,
    so a raw U+2028 (or similar Unicode separator) inside a JSON string does
    not split a record in two.

    Args:
        path: JSONL file, optionally ``.zst``-compressed.

    Yields:
        The decoded value of each non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    if path.suffix == ".zst":
        with path.open("rb") as raw:
            reader = zstandard.ZstdDecompressor().stream_reader(raw, read_across_frames=True)
            # TextIOWrapper gives incremental UTF-8 decoding and line iteration.
            with io.TextIOWrapper(reader, encoding="utf-8") as text:
                yield from _iter_json_lines(text)
    else:
        # Explicit encoding: don't depend on the platform's locale default.
        with path.open(encoding="utf-8") as text:
            yield from _iter_json_lines(text)


def save_fig(
    fig,
    out_dir: pathlib.Path,
    stem: str,
    exts: tuple[str, ...] = ("pdf", "png"),
    png_dpi: int = 150,
) -> None:
    """Write *fig* to ``out_dir/<stem>.<ext>`` for each extension, then close it.

    Args:
        fig: matplotlib figure to save.
        out_dir: existing output directory.
        stem: file name without extension.
        exts: output formats, written in this order (default pdf then png).
        png_dpi: resolution for PNG output; other formats pass ``dpi=None``
            and so use matplotlib's savefig default.
    """
    try:
        for ext in exts:
            path = out_dir / f"{stem}.{ext}"
            fig.savefig(path, bbox_inches="tight", dpi=png_dpi if ext == "png" else None)
            print(f"  wrote {path}")
    finally:
        # Release the figure even if a save fails, so it isn't left open in pyplot.
        plt.close(fig)


def fig1_entropy_decomposition(
    atlas_bf16: pathlib.Path,
    out_dir: pathlib.Path,
    max_per_cat: int = 20,
) -> None:
    """Render Fig 1: stacked bf16 byte-marginal entropy bars per tensor.

    Rows of the bf16 atlas with ``dtype == "bf16"`` are grouped by
    ``tensor_category``.  For each category, in a fixed display order, the
    ``max_per_cat`` rows with the smallest ``input_bytes`` each get one bar
    stacked as sign / exp_high / exp_low / mantissa entropy (bits/value).
    Rows lacking ``H_exp_high`` / ``H_mantissa`` fall back to ``H_byte1`` /
    ``H_byte0``; rows whose four components are all zero or missing are
    skipped.  A dashed line marks the raw 16 bits/value of bf16.

    Args:
        atlas_bf16: atlas JSONL (plain or ``.zst``) read via ``load_jsonl``.
        out_dir: directory receiving ``fig1.pdf`` and ``fig1.png``.
        max_per_cat: maximum number of bars per category (default 20).
    """
    # (colour, legend label) of each stacked segment, bottom to top.
    segments = (
        ("#888", "sign"),
        ("#1f77b4", "exp_high"),
        ("#ff7f0e", "exp_low"),
        ("#2ca02c", "mantissa"),
    )

    by_cat: dict[str, list[dict]] = defaultdict(list)
    for r in load_jsonl(atlas_bf16):
        if r.get("dtype") == "bf16":
            by_cat[r.get("tensor_category", "other")].append(r)
    # NOTE: categories outside this fixed list are not plotted.
    cats = [c for c in ("embedding", "attention", "mlp", "lm_head", "norm", "other") if c in by_cat]

    xs: list[int] = []
    stacks: list[tuple] = []
    x_ticks: list[float] = []
    x_labels: list[str] = []
    for cat in cats:
        # `or 0` keeps explicit JSON nulls from breaking the sort.
        cat_rows = sorted(by_cat[cat], key=lambda r: r.get("input_bytes") or 0)[:max_per_cat]
        start = len(xs)
        for r in cat_rows:
            # `or 0.0` maps explicit nulls to 0 so stacking never sees None.
            comps = (
                r.get("H_sign", 0) or 0.0,
                r.get("H_exp_high", r.get("H_byte1", 0)) or 0.0,
                r.get("H_exp_low", 0) or 0.0,
                r.get("H_mantissa", r.get("H_byte0", 0)) or 0.0,
            )
            if not any(comps):
                continue
            xs.append(len(xs))
            stacks.append(comps)
        if len(xs) > start:
            # This category's bars sit at x = start .. len(xs) - 1; centre the label.
            x_ticks.append((start + len(xs) - 1) / 2)
            x_labels.append(cat)

    fig, ax = plt.subplots(figsize=(11, 5))
    if stacks:
        heights = np.asarray(stacks, dtype=float)  # one row per bar, one column per segment
        bottoms = np.zeros_like(heights)
        bottoms[:, 1:] = np.cumsum(heights, axis=1)[:, :-1]
        for col, (color, _) in enumerate(segments):
            ax.bar(xs, heights[:, col], bottom=bottoms[:, col], color=color)
    ref_line = ax.axhline(16, color="r", ls="--", lw=0.8, label="bf16 raw (16 bits)")
    handles = [plt.Rectangle((0, 0), 1, 1, fc=color, label=lbl) for color, lbl in segments]
    # Explicit handles disable automatic collection, so add the reference line too.
    handles.append(ref_line)
    ax.legend(handles=handles, loc="upper right", ncol=5, fontsize=8)
    ax.set_xticks(x_ticks)
    ax.set_xticklabels(x_labels)
    ax.set_ylabel("bits/value")
    ax.set_title("Fig 1 — bf16 byte-marginal entropy decomposition by tensor category")
    save_fig(fig, out_dir, "fig1")


def fig2_ratio_vs_ceiling_scatter(
    atlas_bf16: pathlib.Path,
    atlas_q4k: pathlib.Path,
    results: pathlib.Path,
    out_dir: pathlib.Path,
) -> None:
    """Render Fig 2: achieved per-tensor ratio vs the atlas ratio ceiling.

    Left panel: ``bf16_split`` ratios from ``bench_bf16_perfile`` result rows
    against the bf16 atlas ``R_marginal`` (falling back to
    ``ratio_ceiling_byte``).  Right panel: ``qb_k4`` ratios from
    ``Q4_K_benchmark`` rows against the Q4_K atlas ``ratio_full`` (falling
    back to ``ratio_ceiling_full``).  Results and atlas rows are joined on
    ``tensor_name``.  Tensors missing from either side, or with a zero/missing
    ratio or ceiling, are dropped.  Each non-empty panel gets a ``y = x``
    reference line.

    Args:
        atlas_bf16: bf16 atlas JSONL (only ``dtype == "bf16"`` rows are used).
        atlas_q4k: Q4_K atlas JSONL.
        results: benchmark results JSONL (plain or ``.zst``).
        out_dir: directory receiving ``fig2.pdf`` and ``fig2.png``.
    """
    bf16_atlas = {
        r["tensor_name"]: r
        for r in load_jsonl(atlas_bf16)
        if r.get("dtype") == "bf16" and r.get("tensor_name")
    }
    q4k_atlas = {r["tensor_name"]: r for r in load_jsonl(atlas_q4k) if r.get("tensor_name")}

    bf16_method: dict[str, float] = {}
    q4k_method: dict[str, float] = {}
    # Results-row `_kind` -> (method whose ratio we read, per-tensor dict to fill).
    wanted = {
        "bench_bf16_perfile": ("bf16_split", bf16_method),
        "Q4_K_benchmark": ("qb_k4", q4k_method),
    }
    for r in load_jsonl(results):
        spec = wanted.get(r.get("_kind"))
        methods = r.get("methods")
        name = r.get("tensor_name")
        if spec is None or not isinstance(methods, dict) or not name:
            continue
        method_name, dest = spec
        m = methods.get(method_name)
        if isinstance(m, dict) and "ratio" in m:
            # NOTE(review): keyed by tensor_name only, so a later row with the
            # same name overwrites an earlier one; confirm results never mix
            # models that share tensor names.
            dest[name] = m["ratio"]

    pts_bf = []
    for name, ratio in bf16_method.items():
        a = bf16_atlas.get(name)
        if a is None:
            continue
        ceil = a.get("R_marginal", a.get("ratio_ceiling_byte"))
        if ceil and ratio:
            pts_bf.append((ceil, ratio))

    pts_q4 = []
    for name, ratio in q4k_method.items():
        a = q4k_atlas.get(name)
        if not a:
            continue
        ceil = a.get("ratio_full") or a.get("ratio_ceiling_full")
        if ceil and ratio:
            pts_q4.append((ceil, ratio))

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 5))
    panels = (
        (ax1, pts_bf, "bf16: bf16_split vs atlas ceiling", "R_marginal (atlas)", 1.6),
        (ax2, pts_q4, "Q4_K: qb_k4 vs atlas ceiling", "ratio_full ceiling (atlas)", 1.15),
    )
    for ax, pts, title, xlabel, axmax in panels:
        if pts:
            xs, ys = zip(*pts)
            ax.scatter(xs, ys, s=4, alpha=0.5)
            ax.plot([1, axmax], [1, axmax], "r--", lw=0.8, label="y = x")
            ax.legend()
        else:
            # Only a labelled line would make a legend meaningful; say why the panel is blank.
            ax.text(0.5, 0.5, "no matching tensors", transform=ax.transAxes, ha="center", va="center")
        ax.set_xlabel(xlabel)
        ax.set_ylabel("best-method ratio")
        ax.set_title(title)
    fig.suptitle("Fig 2 — per-tensor ratio vs atlas ceiling scatter")
    save_fig(fig, out_dir, "fig2")


def fig3_q4k_nibble_entropy_hist(atlas_q4k: pathlib.Path, out_dir: pathlib.Path) -> None:
    """Render Fig 3: histogram of per-tensor Q4_K nibble entropy.

    Uses every atlas row with a finite, non-null ``H_nibble``.  Null and
    NaN/inf values are skipped, because they would make ``hist`` fail.  A
    dashed line marks the 4-bit uniform ceiling (alphabet of 16).  A dotted
    line marks the median and is omitted when there are no values.

    Args:
        atlas_q4k: Q4_K atlas JSONL (plain or ``.zst``).
        out_dir: directory receiving ``fig3.pdf`` and ``fig3.png``.
    """
    nibble_H = [
        v
        for v in (r.get("H_nibble") for r in load_jsonl(atlas_q4k))
        if v is not None and math.isfinite(v)
    ]
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.hist(nibble_H, bins=40, color="#1f77b4", edgecolor="white")
    ax.axvline(4.0, color="r", ls="--", lw=1, label="uniform ceiling 4.0")
    if nibble_H:
        med = float(np.median(nibble_H))
        ax.axvline(med, color="k", ls=":", lw=1, label=f"median = {med:.3f}")
    ax.set_xlabel("H(nibble) — bits/symbol (alphabet 16)")
    ax.set_ylabel("n tensors")
    ax.set_title(f"Fig 3 — Q4_K nibble entropy (n = {len(nibble_H)})")
    ax.legend()
    save_fig(fig, out_dir, "fig3")


def fig4_interlayer_correlation_hist(atlas_layer: pathlib.Path, out_dir: pathlib.Path) -> None:
    """Render Fig 4: histogram of adjacent-layer Pearson correlations.

    Uses every atlas row with a finite, non-null ``corr_pearson``.  NaN
    correlations are dropped, because non-finite values make ``hist`` raise.
    A dashed line marks zero.  A dotted line marks the median and is omitted
    when there are no values.

    Args:
        atlas_layer: inter-layer atlas JSONL (plain or ``.zst``).
        out_dir: directory receiving ``fig4.pdf`` and ``fig4.png``.
    """
    corr = [
        v
        for v in (r.get("corr_pearson") for r in load_jsonl(atlas_layer))
        if v is not None and math.isfinite(v)
    ]
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.hist(corr, bins=50, color="#2ca02c", edgecolor="white")
    ax.axvline(0, color="r", ls="--", lw=1, label="zero")
    if corr:
        med = float(np.median(corr))
        ax.axvline(med, color="k", ls=":", lw=1, label=f"median = {med:+.4f}")
    ax.set_xlabel("Pearson correlation (adjacent layer pair, same role)")
    ax.set_ylabel("n pairs")
    ax.set_title(f"Fig 4 — inter-layer correlation distribution (n = {len(corr)})")
    ax.legend()
    save_fig(fig, out_dir, "fig4")


def fig5_throughput_ratio_pareto(
    results: pathlib.Path,
    out_dir: pathlib.Path,
    min_rows: int = 5,
) -> None:
    """Render Fig 5: per-method geomean ratio vs median decompress throughput.

    Every results row whose ``methods`` is a dict contributes one sample per
    method entry that has both ``ratio`` and ``decompress_MBps``.  Samples
    are grouped by ``(row["_iter"], method_name)``.  Each group with at least
    ``min_rows`` samples becomes one point:

    * x is the median MB/s, on a log scale;
    * y is the (unweighted) geometric mean of the group's positive ratios.

    Colour and legend entry come from the stage tag, and marker size grows
    with log(sample count).  Groups with no positive ratio, or with a
    non-positive median throughput, are skipped because they cannot be drawn.

    Args:
        results: benchmark results JSONL (plain or ``.zst``).
        out_dir: directory receiving ``fig5.pdf`` and ``fig5.png``.
        min_rows: minimum samples for a group to be plotted (default 5).
    """
    # (stage tag, method name) -> [(ratio, decompress MB/s), ...]
    samples: dict[tuple[str, str], list[tuple[float, float]]] = defaultdict(list)
    for r in load_jsonl(results):
        methods = r.get("methods")
        if not isinstance(methods, dict):
            continue
        stage_tag = r.get("_iter", "?")
        for m_name, m_data in methods.items():
            if not isinstance(m_data, dict):
                continue
            ratio = m_data.get("ratio")
            mbps = m_data.get("decompress_MBps")
            if ratio is not None and mbps is not None:
                samples[(stage_tag, m_name)].append((ratio, mbps))

    # Map the internal stage tag stored in `_iter` (kept in the data for
    # source-file trace-back) to a legend colour + human-readable label.
    # Every label has its own colour so legend entries stay distinguishable.
    stage_style = {
        "iter3_v3": ("#1f77b4", "bf16 main bench"),
        "iter4": ("#ff7f0e", "bf16 supplementary"),
        "iter4_7B": ("#8c564b", "7B validation"),
        "iter4_M7": ("#9467bd", "cross-element context"),
        "iter5_Q": ("#2ca02c", "Q4_K bench"),
        "iter6": ("#d62728", "GGUF artifact"),
    }

    fig, ax = plt.subplots(figsize=(10, 6))
    legend_seen: set[str] = set()
    for (stage_tag, m_name), pairs in samples.items():
        if len(pairs) < min_rows:
            continue
        ratios = np.array([p[0] for p in pairs], dtype=float)
        mbps = np.array([p[1] for p in pairs], dtype=float)
        positive = ratios[ratios > 0]
        if positive.size == 0:
            continue  # geometric mean undefined
        # Average the logs over the values actually used, so skipped
        # non-positive ratios don't dilute the mean toward 1.
        geo = float(np.exp(np.log(positive).mean()))
        med_mbps = float(np.median(mbps))
        if not med_mbps > 0:
            continue  # cannot be placed on the log-scaled x axis
        color, stage_label = stage_style.get(stage_tag, ("#888", "other"))
        ax.scatter(
            med_mbps,
            geo,
            s=max(20, math.log1p(len(pairs)) * 8),
            c=color,
            alpha=0.8,
            label=None if stage_label in legend_seen else stage_label,
        )
        legend_seen.add(stage_label)
        ax.annotate(m_name, (med_mbps, geo), fontsize=6, alpha=0.6, xytext=(3, 3), textcoords="offset points")

    ax.set_xscale("log")
    ax.set_xlabel("decompress MB/s (log)")
    ax.set_ylabel("geomean ratio")
    ax.set_title("Fig 5 — throughput / ratio Pareto across all benchmark methods")
    ax.grid(True, ls=":", alpha=0.3)
    if legend_seen:
        ax.legend(loc="lower left", fontsize=8)
    save_fig(fig, out_dir, "fig5")


def main() -> None:
    """Parse CLI options, check that the inputs exist, and render Figs 1-5.

    Default paths are resolved relative to this script's directory.  All
    input paths are checked before the output directory is created or any
    figure is written.  A missing file therefore aborts with a usage error
    instead of leaving a partial set of figures behind.
    """
    repo = pathlib.Path(__file__).parent
    ap = argparse.ArgumentParser(
        description="Render the 5 paper figures (PDF + PNG) from benchmark results and atlas files.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    ap.add_argument(
        "--results",
        type=pathlib.Path,
        default=str(repo / "results" / "results.jsonl.zst"),
        help="benchmark results JSONL, plain or .zst (Figs 2 and 5)",
    )
    ap.add_argument(
        "--atlas-bf16",
        type=pathlib.Path,
        default=str(repo / "provenance" / "iter3_precommit" / "03_atlas" / "results.jsonl"),
        help="bf16 entropy atlas JSONL (Figs 1 and 2)",
    )
    ap.add_argument(
        "--atlas-q4k",
        type=pathlib.Path,
        default=str(repo / "provenance" / "iter5" / "precheck_q" / "atlas_train.jsonl"),
        help="Q4_K atlas JSONL (Figs 2 and 3)",
    )
    ap.add_argument(
        "--atlas-layer",
        type=pathlib.Path,
        default=str(repo / "provenance" / "iter5" / "precheck_l" / "atlas_train.jsonl"),
        help="inter-layer correlation atlas JSONL (Fig 4)",
    )
    ap.add_argument(
        "--out",
        type=pathlib.Path,
        default=str(repo / "figures"),
        help="output directory, created if missing",
    )
    args = ap.parse_args()

    inputs = {
        "--results": args.results,
        "--atlas-bf16": args.atlas_bf16,
        "--atlas-q4k": args.atlas_q4k,
        "--atlas-layer": args.atlas_layer,
    }
    missing = [f"{flag} {path}" for flag, path in inputs.items() if not path.exists()]
    if missing:
        # ap.error prints usage and exits with status 2.
        ap.error("input file(s) not found: " + ", ".join(missing))

    out_dir = args.out
    out_dir.mkdir(parents=True, exist_ok=True)

    print("Fig 1 - bf16 entropy decomposition")
    fig1_entropy_decomposition(args.atlas_bf16, out_dir)
    print("Fig 2 - ratio vs ceiling scatter")
    fig2_ratio_vs_ceiling_scatter(args.atlas_bf16, args.atlas_q4k, args.results, out_dir)
    print("Fig 3 - Q4_K nibble entropy histogram")
    fig3_q4k_nibble_entropy_hist(args.atlas_q4k, out_dir)
    print("Fig 4 - inter-layer correlation histogram")
    fig4_interlayer_correlation_hist(args.atlas_layer, out_dir)
    print("Fig 5 - throughput / ratio Pareto")
    fig5_throughput_ratio_pareto(args.results, out_dir)

    print("\nDone. 10 files in", out_dir)


# Render all figures only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
