{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Figure rendering — v4.0.3\n",
    "\n",
    "Produces the 5 paper figures listed in spec Appendix G:\n",
    "- `fig1.{pdf,png}` — per-tensor stacked bar (sign / top 7 exp bits / low exp bit / 7 fraction bits) grouped by tensor category\n",
    "- `fig2.{pdf,png}` — per-tensor best-method ratio vs `R_marginal`, y=x line, separate panels for bf16 and Q4_K\n",
    "- `fig3.{pdf,png}` — distribution of per-tensor H(nibble) across 1,043 Q4_K tensors, 4.0-bit uniform ceiling annotated\n",
    "- `fig4.{pdf,png}` — distribution of 250 Pearson correlations across (model, role, K), model-level CI band\n",
    "- `fig5.{pdf,png}` — decompress MB/s (log) vs byte-weighted ratio, all methods, Pareto frontier highlighted, color-coded by method class\n",
    "\n",
    "All notebooks read `../results/results.jsonl.zst` as their data source; this notebook additionally reads the atlas JSONL files under `../provenance/` (paths are set in the next cell). No external fetches at run time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import io, json, math, pathlib\n",
    "from collections import defaultdict\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import zstandard\n",
    "\n",
    "RESULTS = pathlib.Path('../results/results.jsonl.zst')\n",
    "ATLAS_BF16 = pathlib.Path('../provenance/iter3_precommit/03_atlas/results.jsonl')\n",
    "ATLAS_Q4K  = pathlib.Path('../provenance/iter5/precheck_q/atlas_train.jsonl')\n",
    "ATLAS_LAYER = pathlib.Path('../provenance/iter5/precheck_l/atlas_train.jsonl')\n",
    "FIG_DIR = pathlib.Path('../figures')\n",
    "FIG_DIR.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "def _iter_json_lines(lines):\n",
    "    \"\"\"Parse every non-blank line of an iterable of text lines as one JSON object.\"\"\"\n",
    "    for line in lines:\n",
    "        if line.strip():\n",
    "            yield json.loads(line)\n",
    "\n",
    "def load_jsonl(path):\n",
    "    \"\"\"Yield one JSON object per non-blank line of a `.jsonl` or `.jsonl.zst` file.\n",
    "\n",
    "    `.zst` input is decompressed as a stream, so memory use does not grow with the\n",
    "    decompressed size and there is no fixed output-size cap. Both paths decode\n",
    "    UTF-8 explicitly rather than relying on the platform default encoding.\n",
    "    \"\"\"\n",
    "    p = pathlib.Path(path)\n",
    "    if p.suffix == '.zst':\n",
    "        with p.open('rb') as fh, zstandard.ZstdDecompressor().stream_reader(fh) as reader:\n",
    "            yield from _iter_json_lines(io.TextIOWrapper(reader, encoding='utf-8'))\n",
    "    else:\n",
    "        with p.open(encoding='utf-8') as fh:\n",
    "            yield from _iter_json_lines(fh)\n",
    "\n",
    "def save_fig(fig, stem):\n",
    "    \"\"\"Write `fig` into FIG_DIR as {stem}.pdf and {stem}.png (png at 150 dpi), then close it.\"\"\"\n",
    "    for ext in ('pdf', 'png'):\n",
    "        path = FIG_DIR / f'{stem}.{ext}'\n",
    "        fig.savefig(path, bbox_inches='tight', dpi=150 if ext == 'png' else None)\n",
    "        print(f'  wrote {path}')\n",
    "    # close so re-running the notebook does not accumulate open figures\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Figure 1 — bf16 entropy decomposition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One stacked bar per bf16 tensor (entropy of each bit field), grouped by tensor category.\n",
    "FIG1_MAX_PER_CAT = 20  # smallest tensors (by input_bytes) shown per category, for legibility\n",
    "CAT_ORDER = ['embedding', 'attention', 'mlp', 'lm_head', 'norm', 'other']\n",
    "COMPONENTS = [('sign', '#888'), ('top 7 exp bits', '#1f77b4'),\n",
    "              ('low exp bit', '#ff7f0e'), ('7 fraction bits', '#2ca02c')]\n",
    "\n",
    "rows = [r for r in load_jsonl(ATLAS_BF16) if r.get('dtype') == 'bf16']\n",
    "by_cat = defaultdict(list)\n",
    "for r in rows:\n",
    "    by_cat[r.get('tensor_category', 'other')].append(r)\n",
    "# Known categories in a fixed order, then any others so none are silently dropped.\n",
    "cats = [c for c in CAT_ORDER if c in by_cat] + sorted((c for c in by_cat if c not in CAT_ORDER), key=str)\n",
    "\n",
    "stacks = []  # per bar: [H_sign, H_exp_high, H_exp_low, H_mantissa]\n",
    "x_ticks, x_labels = [], []\n",
    "for cat in cats:\n",
    "    start = len(stacks)\n",
    "    for r in sorted(by_cat[cat], key=lambda t: t.get('input_bytes', 0))[:FIG1_MAX_PER_CAT]:\n",
    "        parts = [r.get('H_sign', 0), r.get('H_exp_high', r.get('H_byte1', 0)),\n",
    "                 r.get('H_exp_low', 0), r.get('H_mantissa', r.get('H_byte0', 0))]\n",
    "        if any(parts):  # skip tensors that carry no entropy fields\n",
    "            stacks.append(parts)\n",
    "    if len(stacks) > start:\n",
    "        # bars sit at x = start .. len(stacks) - 1; label the midpoint of that run\n",
    "        x_ticks.append((start + len(stacks) - 1) / 2)\n",
    "        x_labels.append(cat)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(11, 5))\n",
    "if stacks:\n",
    "    heights = np.asarray(stacks, dtype=float)\n",
    "    xs = np.arange(len(stacks))\n",
    "    bottom = np.zeros(len(stacks))\n",
    "    for j, (label, color) in enumerate(COMPONENTS):\n",
    "        ax.bar(xs, heights[:, j], bottom=bottom, color=color, label=label)\n",
    "        bottom = bottom + heights[:, j]\n",
    "ax.axhline(16, color='r', ls='--', lw=0.8, label='16 bits (bf16 width)')\n",
    "ax.set_xticks(x_ticks); ax.set_xticklabels(x_labels)\n",
    "ax.set_ylabel('bits/value'); ax.set_title('Fig 1 - bf16 byte-marginal entropy decomposition')\n",
    "ax.legend(loc='upper left', bbox_to_anchor=(1.0, 1.0), fontsize=8)\n",
    "save_fig(fig, 'fig1')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Figure 2 — ratio vs ceiling scatter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-tensor ratio of one fixed method vs the atlas ceiling, one panel each for bf16 and Q4_K.\n",
    "# NOTE(review): each panel plots a single named method, not a per-tensor best-of-methods.\n",
    "BF16_METHOD, Q4K_METHOD = 'bf16_split', 'qb_k4'\n",
    "atlas_bf16 = {r['tensor_name']: r for r in load_jsonl(ATLAS_BF16) if r.get('dtype') == 'bf16'}\n",
    "atlas_q4k = {r['tensor_name']: r for r in load_jsonl(ATLAS_Q4K)}\n",
    "\n",
    "# A single pass over the compressed results file collects both series.\n",
    "bf16_method, qk_method = {}, {}\n",
    "for r in load_jsonl(RESULTS):\n",
    "    if 'methods' not in r: continue\n",
    "    kind = r.get('_kind')\n",
    "    if kind == 'bench_bf16_perfile': target, method = bf16_method, BF16_METHOD\n",
    "    elif kind == 'Q4_K_benchmark': target, method = qk_method, Q4K_METHOD\n",
    "    else: continue\n",
    "    m = r['methods'].get(method)\n",
    "    if m and 'ratio' in m:\n",
    "        target[r.get('tensor_name', '')] = m['ratio']\n",
    "\n",
    "def ceiling_pairs(method_ratios, atlas, *ceiling_keys):\n",
    "    \"\"\"Pair each tensor's method ratio with its atlas ceiling as (ceiling, ratio).\n",
    "\n",
    "    The ceiling is read from the first of `ceiling_keys` present in the atlas row;\n",
    "    tensors missing from the atlas, or with a falsy ceiling or ratio, are skipped.\n",
    "    \"\"\"\n",
    "    pts = []\n",
    "    for name, ratio in method_ratios.items():\n",
    "        a = atlas.get(name)\n",
    "        if a is None: continue\n",
    "        ceil = next((a[k] for k in ceiling_keys if k in a), None)\n",
    "        if ceil and ratio: pts.append((ceil, ratio))\n",
    "    return pts\n",
    "\n",
    "pts_bf = ceiling_pairs(bf16_method, atlas_bf16, 'R_marginal', 'ratio_ceiling_byte')\n",
    "pts_q4 = ceiling_pairs(qk_method, atlas_q4k, 'ratio_full', 'ratio_ceiling_full')\n",
    "\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 5))\n",
    "for ax, pts, fmt, method in [(ax1, pts_bf, 'bf16', BF16_METHOD), (ax2, pts_q4, 'Q4_K', Q4K_METHOD)]:\n",
    "    if pts:\n",
    "        xs, ys = zip(*pts)\n",
    "        ax.scatter(xs, ys, s=4, alpha=0.5)\n",
    "        # y = x from ratio 1 across the full data range, so no point lies beyond the reference line\n",
    "        lo, hi = min(1.0, min(xs), min(ys)), max(max(xs), max(ys))\n",
    "        ax.plot([lo, hi], [lo, hi], 'r--', lw=0.8)\n",
    "    else:\n",
    "        ax.text(0.5, 0.5, 'no matching tensors', transform=ax.transAxes, ha='center', va='center')\n",
    "    ax.set_xlabel('R_marginal (atlas)'); ax.set_ylabel(f'{method} ratio'); ax.set_title(f'{fmt}: {method} vs atlas')\n",
    "fig.suptitle('Fig 2 - per-tensor ratio vs R_marginal scatter')\n",
    "save_fig(fig, 'fig2')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Figure 3 — Q4_K nibble entropy histogram"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Histogram of per-tensor nibble entropy over the Q4_K atlas; a nibble has 16\n",
    "# possible values, so 4.0 bits/symbol is the uniform-distribution ceiling.\n",
    "nibble_H = []\n",
    "for rec in load_jsonl(ATLAS_Q4K):\n",
    "    if 'H_nibble' in rec:\n",
    "        nibble_H.append(rec['H_nibble'])\n",
    "med = float(np.median(nibble_H))\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(8, 4))\n",
    "ax.hist(nibble_H, bins=40, color='#1f77b4', edgecolor='white')\n",
    "ax.axvline(4.0, color='r', ls='--', lw=1, label='uniform ceiling 4.0')\n",
    "ax.axvline(med, color='k', ls=':', lw=1, label=f'median = {med:.3f}')\n",
    "ax.set_xlabel('H(nibble) bits/symbol')\n",
    "ax.set_ylabel('n tensors')\n",
    "ax.set_title(f'Fig 3 - Q4_K nibble entropy (n = {len(nibble_H)})')\n",
    "ax.legend()\n",
    "save_fig(fig, 'fig3')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Figure 4 — inter-layer Pearson correlation histogram"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Histogram of inter-layer Pearson correlations. Missing or non-finite values\n",
    "# (null/NaN) are dropped: they turn the median into NaN or raise, and can break\n",
    "# the histogram's automatic binning.\n",
    "corr_raw = [r['corr_pearson'] for r in load_jsonl(ATLAS_LAYER) if 'corr_pearson' in r]\n",
    "corr = [c for c in corr_raw if isinstance(c, (int, float)) and math.isfinite(c)]\n",
    "n_dropped = len(corr_raw) - len(corr)\n",
    "if n_dropped:\n",
    "    print(f'  dropped {n_dropped} missing/non-finite correlations')\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(8, 4))\n",
    "ax.hist(corr, bins=50, color='#2ca02c', edgecolor='white')\n",
    "ax.axvline(0, color='r', ls='--', lw=1, label='zero')\n",
    "if corr:\n",
    "    med = float(np.median(corr))\n",
    "    ax.axvline(med, color='k', ls=':', lw=1, label=f'median = {med:+.4f}')\n",
    "# NOTE(review): spec Appendix G also asks for a model-level CI band; not drawn here.\n",
    "ax.set_xlabel('Pearson correlation'); ax.set_ylabel('n pairs')\n",
    "ax.set_title(f'Fig 4 - inter-layer correlation (n = {len(corr)})'); ax.legend()\n",
    "save_fig(fig, 'fig4')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Figure 5 — throughput / ratio Pareto"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One point per (iteration tag, method): geomean ratio vs median decompress MB/s.\n",
    "# Colours are keyed by iteration tag (NOTE(review): spec asks for method class).\n",
    "FIG5_MIN_RECORDS = 5  # groups with fewer measurements are not plotted\n",
    "ITER_COLORS = {'iter3_v3': '#1f77b4', 'iter4': '#ff7f0e', 'iter4_7B': '#ff7f0e',\n",
    "               'iter4_M7': '#9467bd', 'iter5_Q': '#2ca02c', 'iter6': '#d62728'}\n",
    "\n",
    "method_ratios = defaultdict(list); method_mbps = defaultdict(list)\n",
    "for r in load_jsonl(RESULTS):\n",
    "    if 'methods' not in r: continue\n",
    "    iter_tag = r.get('_iter', '?')\n",
    "    for m_name, m_data in r['methods'].items():\n",
    "        if not isinstance(m_data, dict): continue\n",
    "        ratio = m_data.get('ratio'); mbps = m_data.get('decompress_MBps')\n",
    "        if ratio is not None and mbps is not None:\n",
    "            method_ratios[(iter_tag, m_name)].append(ratio)\n",
    "            method_mbps[(iter_tag, m_name)].append(mbps)\n",
    "\n",
    "def geomean_positive(values):\n",
    "    \"\"\"Geometric mean of the strictly positive entries of `values`; None if there are none.\n",
    "\n",
    "    The divisor is the count of positive entries, so skipped values do not bias the mean down.\n",
    "    \"\"\"\n",
    "    logs = [math.log(x) for x in values if x > 0]\n",
    "    return math.exp(sum(logs) / len(logs)) if logs else None\n",
    "\n",
    "def pareto_front(points):\n",
    "    \"\"\"Return the (mbps, ratio) points not dominated on both axes (higher is better), by ascending mbps.\"\"\"\n",
    "    front, best_ratio = [], -math.inf\n",
    "    # Walk from fastest to slowest; a point survives only if it beats every faster point's ratio.\n",
    "    for mbps, ratio in sorted(points, key=lambda p: (-p[0], -p[1])):\n",
    "        if ratio > best_ratio:\n",
    "            front.append((mbps, ratio))\n",
    "            best_ratio = ratio\n",
    "    return front[::-1]\n",
    "\n",
    "agg = []\n",
    "for key, rs in method_ratios.items():\n",
    "    if len(rs) < FIG5_MIN_RECORDS: continue\n",
    "    geo = geomean_positive(rs)\n",
    "    if geo is None: continue\n",
    "    agg.append((key, geo, float(np.median(method_mbps[key])), len(rs)))\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "for (iter_tag, m_name), geo, mbps, n in agg:\n",
    "    ax.scatter(mbps, geo, s=max(20, math.log1p(n)*8), c=ITER_COLORS.get(iter_tag, '#888'), alpha=0.8)\n",
    "    ax.annotate(m_name, (mbps, geo), fontsize=6, alpha=0.6, xytext=(3,3), textcoords='offset points')\n",
    "# Frontier only over points the log x-axis can display.\n",
    "front = pareto_front([(mbps, geo) for _, geo, mbps, _ in agg if mbps > 0])\n",
    "if front:\n",
    "    fx, fy = zip(*front)\n",
    "    ax.plot(fx, fy, color='k', lw=1, marker='o', mfc='none', ms=9, label='Pareto frontier')\n",
    "# Empty proxy artists give one legend entry per iteration tag present.\n",
    "for tag in sorted({key[0] for key, *_ in agg}, key=str):\n",
    "    ax.scatter([], [], color=ITER_COLORS.get(tag, '#888'), label=str(tag))\n",
    "ax.set_xscale('log')\n",
    "# NOTE(review): the spec asks for a byte-weighted ratio; records are weighted equally here.\n",
    "ax.set_xlabel('decompress MB/s (log)'); ax.set_ylabel('geomean ratio (unweighted)')\n",
    "ax.set_title('Fig 5 - throughput / ratio Pareto')\n",
    "ax.grid(True, ls=':', alpha=0.3)\n",
    "if agg: ax.legend(fontsize=7)\n",
    "save_fig(fig, 'fig5')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
  "language_info": {"name": "python", "version": "3.11"}
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
