#!/usr/bin/env python3
"""render_figures.py - render 5 paper figures from results.jsonl.zst.

Standalone CLI; same logic as notebooks/figures.ipynb but produces all 10
(5 figures x {pdf,png}) outputs in one shot. Run from the repo root:

    python3 render_figures.py [--out figures/] [--results results/results.jsonl.zst]
                              [--atlas-bf16 PATH] [--atlas-q4k PATH] [--atlas-layer PATH]

Required: numpy, matplotlib, zstandard.
"""
from __future__ import annotations

import argparse
import json
import math
import pathlib
from collections import defaultdict

import numpy as np
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import zstandard


# -----------------------------------------------------------------------------
# Method classification for Fig 5 (paper §10.1 caption: general-purpose,
# byte-grouped, format-aware, trained-profile).
# -----------------------------------------------------------------------------
# Maps each benchmark method name to one of the four Fig 5 method classes.
# Written class-first so that every class's members sit together; the
# comprehension flattens it into a {method_name: class_name} lookup.
METHOD_CLASS = {
    method: method_class
    for method_class, methods in (
        # general-purpose codecs
        (
            "general-purpose",
            (
                "gzip6",
                "brotli_q9",
                "xz9",
                "whole_xz6",
                "whole_zstd3",
                "whole_zstd9",
                "whole_zstd19",
                "zstd3_raw",
                "zstd19_raw",
            ),
        ),
        # byte-grouped / decomposed-stream codecs
        (
            "byte-grouped",
            (
                "zstd19_bytegrouped",
                "decomp_perstream_zstd3",
                "decomp_perstream_zstd19",
                "decomp_perstream_xz6",
                "decomp_perstream_zstd19_bgscale",
                "decomp_qb_full",
            ),
        ),
        # format-aware methods (this paper's bf16_split and Q4_K mixture-CDF coder)
        (
            "format-aware",
            (
                "bf16_split",
                "qb_k4",
                "qb_k8",
                "qb_k16",
                "qa",
            ),
        ),
        # trained-profile methods (OpenZL and its variants, dictionary methods)
        (
            "trained-profile",
            (
                "scape_per_file",
                "bf16_pretrained_k1",
                "bf16_mixture_k4",
                "zstd19_dict",
            ),
        ),
    )
    for method in methods
}

# Per-class (color, marker) pair used for the Fig 5 scatter points; keys are
# the class names that appear as values in METHOD_CLASS.
CLASS_STYLE = {
    "general-purpose": ("#1f77b4", "o"),
    "byte-grouped": ("#ff7f0e", "s"),
    "format-aware": ("#2ca02c", "^"),
    "trained-profile": ("#d62728", "D"),
}


def load_jsonl(path: pathlib.Path):
    """Yield one parsed JSON value per non-blank line of a JSONL file.

    A path whose suffix is ``.zst`` is decompressed on the fly with
    zstandard; any other path is read as plain text. Both branches decode
    as UTF-8 and both split records with the text layer's normal line
    iteration, so a record is never broken at Unicode separators (e.g.
    U+2028) that may legally appear unescaped inside a JSON string.

    The compressed branch streams, so memory use is bounded by the longest
    line rather than by the decompressed size of the whole file.

    Args:
        path: location of the ``.jsonl`` or ``.jsonl.zst`` file.

    Yields:
        The object decoded by ``json.loads`` for each non-blank line.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    import io  # local import: only needed to wrap the zstd byte stream as text

    def parsed(lines):
        # Whitespace-only lines are tolerated and skipped.
        for line in lines:
            if line.strip():
                yield json.loads(line)

    if path.suffix == ".zst":
        with path.open("rb") as compressed:
            reader = zstandard.ZstdDecompressor().stream_reader(compressed)
            with io.TextIOWrapper(reader, encoding="utf-8") as text:
                yield from parsed(text)
    else:
        # Explicit encoding so the result does not depend on the locale and
        # matches the UTF-8 decoding of the compressed branch.
        with path.open(encoding="utf-8") as text:
            yield from parsed(text)


def save_fig(fig, out_dir: pathlib.Path, stem: str) -> None:
    """Write *fig* to ``out_dir`` as ``<stem>.pdf`` and ``<stem>.png``, then close it.

    The PNG is rendered at 150 dpi; the PDF passes ``dpi=None`` so matplotlib
    uses its own default. Each written path is echoed to stdout.
    """
    dpi_for_ext = {"pdf": None, "png": 150}
    for ext, dpi in dpi_for_ext.items():
        target = out_dir / f"{stem}.{ext}"
        fig.savefig(target, bbox_inches="tight", dpi=dpi)
        print(f"  wrote {target}")
    # Release the figure so the five figures do not accumulate in pyplot.
    plt.close(fig)


def fig1_entropy_decomposition(atlas_bf16: pathlib.Path, out_dir: pathlib.Path) -> None:
    """Render Figure 1: per-tensor stacked bars of bf16 byte-marginal entropy.

    bf16 layout: byte 1 holds [sign(1) + exp_high(7)], byte 0 holds
    [exp_low(1) + mantissa(7)]. The atlas measures byte-marginal entropies
    (H_byte1, H_byte0) - the entropy ceiling under Prop. 1 is governed by
    their sum, so the figure stacks them and compares against 16 bits/value.

    Only atlas rows with ``dtype == "bf16"`` are used. Rows are grouped by
    ``tensor_category`` (missing -> "other"); categories outside the fixed
    display order below are not drawn. Output goes to ``fig1.{pdf,png}``.
    """
    display_order = ("embedding", "attention", "mlp", "lm_head", "norm", "other")

    grouped: dict[str, list[dict]] = defaultdict(list)
    for rec in load_jsonl(atlas_bf16):
        if rec.get("dtype") == "bf16":
            grouped[rec.get("tensor_category", "other")].append(rec)

    fig, ax = plt.subplots(figsize=(12, 5))
    next_x = 0  # x position of the next bar; also the count of bars drawn so far
    tick_positions: list[float] = []
    tick_names: list[str] = []
    for category in display_order:
        if category not in grouped:
            continue
        # Alphabetical by tensor name, capped at 30 per category to keep the
        # figure readable - the message is that every tensor sits near the
        # same place, not the per-tensor differences.
        members = sorted(grouped[category], key=lambda rec: rec.get("tensor_name", ""))
        group_start = next_x
        for rec in members[:30]:
            high = rec.get("H_byte1", 0)  # sign + exp_high (7+1 = 8 bits)
            low = rec.get("H_byte0", 0)  # exp_low + mantissa (1+7 = 8 bits)
            if not high or not low:
                # Missing or zero entropy in either byte: tensor is not drawn.
                continue
            # Only the very first bar carries legend labels, so each series
            # appears exactly once in the legend.
            is_first_bar = next_x == 0
            ax.bar(
                next_x,
                high,
                color="#1f77b4",
                label="byte 1: sign + exp_high" if is_first_bar else None,
            )
            ax.bar(
                next_x,
                low,
                bottom=high,
                color="#2ca02c",
                label="byte 0: exp_low + mantissa" if is_first_bar else None,
            )
            next_x += 1
        if next_x > group_start:
            # Centre the category label under its bars and rule off the group.
            tick_positions.append((group_start + next_x - 1) / 2)
            tick_names.append(category)
            ax.axvline(next_x - 0.5, color="#cccccc", lw=0.5, zorder=0)

    ax.axhline(16, color="r", ls="--", lw=0.8, label="raw bf16 (16 bits/value)")
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_names)
    ax.set_xlim(-0.5, next_x - 0.5)
    ax.set_ylim(0, 17)
    ax.set_ylabel("byte-marginal entropy (bits / value)")
    ax.set_xlabel("tensor (grouped by category)")
    ax.set_title("Figure 1 - bf16 byte-marginal entropy decomposition (per tensor, by category)")
    ax.legend(loc="lower right", fontsize=8, framealpha=0.9)
    save_fig(fig, out_dir, "fig1")


def _fig2_scatter_panel(ax, pts, lims, diag_label, xlabel, ylabel, title) -> None:
    """Draw one Fig 2 panel: (ceiling, achieved) scatter plus the y = x line.

    ``lims`` is the (lo, hi) pair used for both axes and for the extent of
    the diagonal; points outside it are not visible in the panel.
    """
    if pts:
        xs, ys = zip(*pts)
        ax.scatter(xs, ys, s=14, alpha=0.55, c="#1f77b4", edgecolors="none")
    lo, hi = lims
    ax.plot([lo, hi], [lo, hi], "r--", lw=1.0, label=diag_label)
    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(loc="upper left")
    ax.grid(True, ls=":", alpha=0.3)


def fig2_ratio_vs_ceiling_scatter(
    atlas_bf16: pathlib.Path,
    atlas_q4k: pathlib.Path,
    results: pathlib.Path,
    out_dir: pathlib.Path,
) -> None:
    """Per-tensor best-method ratio vs per-tensor R_marginal, y=x line.

    Top panel: bf16 - bf16_split achieved ratio vs `ceiling_ratio` (which is
    R_marginal under iid byte-marginal coding from Prop. 1).
    Bottom panel: Q4_K - qb_k4 achieved ratio vs 8/H_byte (nibble-stream
    marginal ceiling derived from the per-tensor byte entropy).

    Atlas rows and result rows are joined on ``tensor_name``. Atlas rows
    without a ``tensor_name`` cannot be joined and are skipped; a tensor is
    plotted only when both its ceiling and its achieved ratio are present
    and non-zero. Output goes to ``fig2.{pdf,png}``.

    NOTE(review): the join key is the bare tensor name, so if the inputs mix
    several models whose tensors share names, later rows overwrite earlier
    ones - confirm the names are unique in the atlas/results files.
    """
    bf16_atlas = {
        r["tensor_name"]: r
        for r in load_jsonl(atlas_bf16)
        if r.get("dtype") == "bf16" and "tensor_name" in r
    }

    # Which method's ratio is collected for which result kind.
    bf16_method: dict[str, float] = {}
    q4k_method: dict[str, float] = {}
    collect = {
        "bench_bf16_perfile": ("bf16_split", bf16_method),
        "Q4_K_benchmark": ("qb_k4", q4k_method),
    }
    for r in load_jsonl(results):
        if "methods" not in r:
            continue
        target = collect.get(r.get("_kind"))
        if target is None:
            continue
        method_name, sink = target
        m = r["methods"].get(method_name)
        if m and "ratio" in m:
            sink[r.get("tensor_name", "")] = m["ratio"]

    pts_bf: list[tuple[float, float]] = []
    for name, achieved in bf16_method.items():
        a = bf16_atlas.get(name)
        if not a:
            continue
        # The ceiling is looked up under the first of these keys that holds
        # a non-zero value.
        ceiling = a.get("ceiling_ratio") or a.get("R_marginal") or a.get("ratio_ceiling_byte")
        if ceiling and achieved:
            pts_bf.append((ceiling, achieved))

    q4k_atlas = {r["tensor_name"]: r for r in load_jsonl(atlas_q4k) if "tensor_name" in r}
    pts_q4: list[tuple[float, float]] = []
    for name, achieved in q4k_method.items():
        a = q4k_atlas.get(name)
        if not a:
            continue
        h_byte = a.get("H_byte")
        # Same guard as the bf16 panel: a null/zero ratio is not a point.
        if h_byte and achieved:
            pts_q4.append((8.0 / h_byte, achieved))  # nibble-stream byte-marginal ceiling

    fig, (ax_top, ax_bot) = plt.subplots(2, 1, figsize=(8, 9.5))

    # Top panel: bf16. bf16_split is in Prop. 1's coder class and cannot
    # exceed R_marginal; points cluster on / just below y = x.
    _fig2_scatter_panel(
        ax_top,
        pts_bf,
        (1.40, 1.55),
        "y = x (R_marginal)",
        "R_marginal - iid byte-marginal ceiling, atlas",
        "bf16_split achieved ratio",
        f"bf16 (n = {len(pts_bf)}) - achieved ratio vs R_marginal, Prop. 1 ceiling",
    )

    # Bottom panel: Q4_K. qb_k4 uses within-superblock joint structure -
    # explicitly outside Prop. 1's iid-byte-marginal class - so the points
    # sit above the nibble-stream marginal line (the x axis here is the
    # nibble-only ceiling, 8 / H_byte, since the full Q4_K tensor-stream
    # ceiling is not measured per tensor in the atlas).
    _fig2_scatter_panel(
        ax_bot,
        pts_q4,
        (1.02, 1.06),
        "y = x (nibble-stream marginal)",
        "nibble-stream marginal ceiling: 8 / H_byte (atlas)",
        "qb_k4 achieved ratio (full Q4_K tensor stream)",
        f"Q4_K (n = {len(pts_q4)}) - achieved ratio vs nibble-only ceiling",
    )

    fig.suptitle("Figure 2 - per-tensor achieved ratio vs R_marginal (Prop. 1 ceiling)")
    fig.tight_layout()
    save_fig(fig, out_dir, "fig2")


def fig3_q4k_nibble_entropy_hist(atlas_q4k: pathlib.Path, out_dir: pathlib.Path) -> None:
    """Render Figure 3: histogram of per-tensor H(nibble) over the Q4_K atlas.

    Every atlas row that has an ``H_nibble`` field contributes one sample.
    The 4.0-bit uniform ceiling and the sample median are marked with
    vertical lines. Output goes to ``fig3.{pdf,png}``.
    """
    entropies = []
    for rec in load_jsonl(atlas_q4k):
        if "H_nibble" in rec:
            entropies.append(rec["H_nibble"])

    fig, ax = plt.subplots(figsize=(9, 4.5))
    ax.hist(entropies, bins=40, color="#1f77b4", edgecolor="white")

    # Reference lines: the uniform-nibble ceiling and the observed median.
    ax.axvline(4.0, color="r", ls="--", lw=1.2, label="uniform ceiling (4.0 bits)")
    median_bits = float(np.median(entropies))
    ax.axvline(median_bits, color="k", ls=":", lw=1.0, label=f"median = {median_bits:.3f} bits")

    ax.set_xlabel("H(nibble) - bits / symbol (alphabet 16)")
    ax.set_ylabel("number of Q4_K tensors")
    ax.set_xlim(3.80, 4.02)
    sample_count = len(entropies)
    ax.set_title(
        f"Figure 3 - Q4_K nibble entropy distribution (n = {sample_count} Q4_K-typed tensors)"
    )
    ax.legend(loc="upper left")
    ax.grid(True, ls=":", alpha=0.3)
    save_fig(fig, out_dir, "fig3")


def fig4_interlayer_correlation_hist(atlas_layer: pathlib.Path, out_dir: pathlib.Path) -> None:
    """Render Figure 4: histogram of inter-layer Pearson correlations.

    Every atlas row that has a ``corr_pearson`` field contributes one
    sample. The model-level 95% CI band, the zero line and the sample median
    are overlaid. Output goes to ``fig4.{pdf,png}``.
    """
    correlations = []
    for rec in load_jsonl(atlas_layer):
        if "corr_pearson" in rec:
            correlations.append(rec["corr_pearson"])

    fig, ax = plt.subplots(figsize=(9, 4.5))
    ax.hist(correlations, bins=50, color="#2ca02c", edgecolor="white")

    # 95% CI band from paper §8.2 caption: [-0.002, +0.003]
    ci_low, ci_high = -0.002, 0.003
    ax.axvspan(
        ci_low,
        ci_high,
        color="#ffe4b3",
        alpha=0.6,
        zorder=0,
        label="model-level 95% CI [-0.002, +0.003]",
    )
    ax.axvline(0, color="r", ls="--", lw=1.0, label="zero (no correlation)")
    median_corr = float(np.median(correlations))
    ax.axvline(median_corr, color="k", ls=":", lw=1.0, label=f"median = {median_corr:+.4f}")

    ax.set_xlabel("Pearson correlation, adjacent same-role layer pair")
    ax.set_ylabel("number of pairs")
    pair_count = len(correlations)
    ax.set_title(
        "Figure 4 - inter-layer correlation distribution "
        f"(n = {pair_count} pairs across two Qwen2.5 source models)"
    )
    ax.legend(loc="upper right", fontsize=9, framealpha=0.9)
    ax.grid(True, ls=":", alpha=0.3)
    save_fig(fig, out_dir, "fig4")


def _is_pareto_optimal(points: list[tuple[float, float]]) -> list[bool]:
    """Pareto front: points where no other point has higher ratio AND higher
    throughput. Returns boolean mask in the input order."""
    n = len(points)
    mask = [True] * n
    for i, (mbps_i, r_i) in enumerate(points):
        for j, (mbps_j, r_j) in enumerate(points):
            if i == j:
                continue
            if mbps_j >= mbps_i and r_j >= r_i and (mbps_j > mbps_i or r_j > r_i):
                mask[i] = False
                break
    return mask


def fig5_throughput_ratio_pareto(results: pathlib.Path, out_dir: pathlib.Path) -> None:
    """Render Figure 5: decompress MB/s (log) vs byte-weighted ratio.

    Points are coloured and shaped by method class (general-purpose /
    byte-grouped / format-aware / trained-profile) and the Pareto front is
    highlighted. There is one point per (corpus, method) so that bf16
    methods and Q4_K methods are not conflated.

    Per point, the ratio is the mean of the per-record ratios weighted by
    ``input_bytes`` (a record with missing or zero ``input_bytes`` gets
    weight 1), and the throughput is the median ``decompress_MBps``.
    (corpus, method) pairs with fewer than 5 usable records are not drawn.
    Methods absent from METHOD_CLASS are styled as "general-purpose".
    Output goes to ``fig5.{pdf,png}``.
    """
    # Aggregate by (corpus, method): list of (ratio, mbps, input_bytes). The
    # corpus tag is derived from `_kind` so a method that runs on both bf16
    # and Q4_K shows up as two separate points.
    corpus_for_kind = {
        "bench_bf16_perfile": "bf16",
        "bench_7B": "bf16",
        "Q4_K_benchmark": "Q4_K",
    }
    agg_raw: dict[tuple[str, str], list[tuple[float, float, int]]] = defaultdict(list)
    for r in load_jsonl(results):
        if "methods" not in r:
            continue
        corpus = corpus_for_kind.get(r.get("_kind", ""))
        if not corpus:
            continue
        ib = r.get("input_bytes", 0) or 0
        for m_name, m_data in r["methods"].items():
            if not isinstance(m_data, dict):
                continue
            ratio = m_data.get("ratio")
            mbps = m_data.get("decompress_MBps")
            if ratio is not None and mbps is not None and ratio > 0 and mbps > 0:
                agg_raw[(corpus, m_name)].append((ratio, mbps, ib))

    points: list[tuple[str, str, float, float, int]] = []
    for (corpus, m_name), vals in agg_raw.items():
        if len(vals) < 5:
            continue
        # Byte-weighted arithmetic mean of the ratio (matches paper
        # convention). Numerator and denominator share one weight list, so
        # the result is a true weighted mean even when only some records
        # lack input_bytes.
        weights = [nbytes or 1 for _, _, nbytes in vals]
        bw_ratio = sum(ratio * w for (ratio, _, _), w in zip(vals, weights)) / sum(weights)
        med_mbps = float(np.median([v[1] for v in vals]))
        points.append((corpus, m_name, bw_ratio, med_mbps, len(vals)))

    # Pareto front computed globally across all (corpus, method) points: a
    # point is off the front if another point is at least as good on both
    # mbps and ratio and strictly better on one.
    xy = [(mbps, ratio) for (_, _, ratio, mbps, _) in points]
    pareto_mask = _is_pareto_optimal(xy)

    fig, ax = plt.subplots(figsize=(12, 7))
    cls_seen: set[str] = set()
    for (corpus, m_name, ratio, mbps, n), is_pareto in zip(points, pareto_mask):
        cls = METHOD_CLASS.get(m_name, "general-purpose")
        color, marker = CLASS_STYLE[cls]
        # Label only the first point of each class so the legend has one
        # entry per class.
        label = cls if cls not in cls_seen else None
        cls_seen.add(cls)
        edge = "black" if is_pareto else "none"
        lw = 1.4 if is_pareto else 0
        ax.scatter(
            mbps, ratio,
            # Marker area grows with log(1 + record count), floor of 40.
            s=max(40, math.log1p(n) * 18),
            c=color, marker=marker, alpha=0.85,
            edgecolors=edge, linewidths=lw,
            label=label,
        )
        # Disambiguate the corpus in the label suffix when a method runs on
        # both bf16 and Q4_K (e.g. zstd19_bytegrouped).
        tag = f"{m_name} ({corpus})"
        ax.annotate(tag, (mbps, ratio), fontsize=6.5, alpha=0.75,
                    xytext=(4, 4), textcoords="offset points")

    # Pareto front polyline (sorted by mbps).
    front = sorted([xy[i] for i, ok in enumerate(pareto_mask) if ok])
    if len(front) >= 2:
        fx, fy = zip(*front)
        ax.plot(fx, fy, "k-", lw=0.9, alpha=0.5, zorder=0, label="Pareto front")

    ax.set_xscale("log")
    ax.set_xlabel("decompress MB/s (log scale)")
    ax.set_ylabel("byte-weighted ratio")
    ax.set_title(
        "Figure 5 - throughput / ratio Pareto across all measured methods "
        "(separate points per corpus)"
    )
    ax.grid(True, ls=":", alpha=0.35, which="both")
    ax.legend(loc="center right", fontsize=9, framealpha=0.9)
    save_fig(fig, out_dir, "fig5")


def main() -> None:
    """Parse the command line and render all five figures into ``--out``.

    Default input and output locations are resolved relative to the
    directory containing this script. The output directory is created if it
    does not exist.
    """
    repo = pathlib.Path(__file__).parent
    option_defaults = {
        "--results": repo / "results" / "results.jsonl.zst",
        "--atlas-bf16": repo / "provenance" / "iter3_precommit" / "03_atlas" / "results.jsonl",
        "--atlas-q4k": repo / "provenance" / "iter5" / "precheck_q" / "atlas_train.jsonl",
        "--atlas-layer": repo / "provenance" / "iter5" / "precheck_l" / "atlas_train.jsonl",
        "--out": repo / "figures",
    }
    ap = argparse.ArgumentParser()
    for flag, default_path in option_defaults.items():
        ap.add_argument(flag, default=str(default_path))
    args = ap.parse_args()

    out_dir = pathlib.Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)

    results = pathlib.Path(args.results)
    atlas_bf16 = pathlib.Path(args.atlas_bf16)
    atlas_q4k = pathlib.Path(args.atlas_q4k)
    atlas_layer = pathlib.Path(args.atlas_layer)

    # (progress banner, renderer, input paths); out_dir is appended per call.
    steps = (
        ("Fig 1 - bf16 entropy decomposition", fig1_entropy_decomposition, (atlas_bf16,)),
        (
            "Fig 2 - ratio vs ceiling scatter",
            fig2_ratio_vs_ceiling_scatter,
            (atlas_bf16, atlas_q4k, results),
        ),
        ("Fig 3 - Q4_K nibble entropy histogram", fig3_q4k_nibble_entropy_hist, (atlas_q4k,)),
        (
            "Fig 4 - inter-layer correlation histogram",
            fig4_interlayer_correlation_hist,
            (atlas_layer,),
        ),
        ("Fig 5 - throughput / ratio Pareto", fig5_throughput_ratio_pareto, (results,)),
    )
    for banner, render, inputs in steps:
        print(banner)
        render(*inputs, out_dir)

    print("\nDone. 10 files in", out_dir)


if __name__ == "__main__":
    main()
