{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Figure rendering\n",
    "\n",
    "Produces the 5 paper figures listed in paper Appendix G (Figures):\n",
    "- `fig1.{pdf,png}` - per-tensor stacked bar of byte-marginal entropies (byte 1: sign + exp_high; byte 0: exp_low + mantissa), grouped by tensor category\n",
    "- `fig2.{pdf,png}` - per-tensor best-method ratio vs `R_marginal`, y=x line, separate panels for bf16 (Prop. 1 ceiling, bf16_split) and Q4_K (nibble-stream marginal, qb_k4)\n",
    "- `fig3.{pdf,png}` - distribution of per-tensor H(nibble) across the Q4_K atlas, 4.0-bit uniform ceiling annotated\n",
    "- `fig4.{pdf,png}` - distribution of 250 Pearson correlations across (model, role, K), model-level 95% CI band [-0.002, +0.003] overlaid\n",
    "- `fig5.{pdf,png}` - decompress MB/s (log) vs byte-weighted ratio, classed by general-purpose / byte-grouped / format-aware / trained-profile, Pareto frontier highlighted\n",
    "\n",
    "Reads `../results/results.jsonl.zst` plus the three atlas files in `../provenance/` (bf16 byte-entropy atlas, Q4_K atlas, inter-layer correlation atlas). No external fetches at run time.\n",
    "\n",
    "Identical logic lives in `../render_figures.py` (CLI). Editing this notebook is fine but please mirror any logic change to the CLI script."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import io, json, math, pathlib\n",
    "from collections import defaultdict\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import zstandard\n",
    "\n",
    "# All inputs are local paths relative to this notebook; nothing is fetched at run time.\n",
    "RESULTS = pathlib.Path('../results/results.jsonl.zst')\n",
    "ATLAS_BF16 = pathlib.Path('../provenance/iter3_precommit/03_atlas/results.jsonl')\n",
    "ATLAS_Q4K  = pathlib.Path('../provenance/iter5/precheck_q/atlas_train.jsonl')\n",
    "ATLAS_LAYER = pathlib.Path('../provenance/iter5/precheck_l/atlas_train.jsonl')\n",
    "FIG_DIR = pathlib.Path('../figures')\n",
    "FIG_DIR.mkdir(exist_ok=True)\n",
    "\n",
    "# Method name -> method class; drives colour/marker in Figure 5.\n",
    "METHOD_CLASS = {\n",
    "    'gzip6': 'general-purpose', 'brotli_q9': 'general-purpose', 'xz9': 'general-purpose',\n",
    "    'whole_xz6': 'general-purpose', 'whole_zstd3': 'general-purpose', 'whole_zstd9': 'general-purpose',\n",
    "    'whole_zstd19': 'general-purpose', 'zstd3_raw': 'general-purpose', 'zstd19_raw': 'general-purpose',\n",
    "    'zstd19_bytegrouped': 'byte-grouped', 'decomp_perstream_zstd3': 'byte-grouped',\n",
    "    'decomp_perstream_zstd19': 'byte-grouped', 'decomp_perstream_xz6': 'byte-grouped',\n",
    "    'decomp_perstream_zstd19_bgscale': 'byte-grouped', 'decomp_qb_full': 'byte-grouped',\n",
    "    'bf16_split': 'format-aware', 'qb_k4': 'format-aware', 'qb_k8': 'format-aware',\n",
    "    'qb_k16': 'format-aware', 'qa': 'format-aware',\n",
    "    'scape_per_file': 'trained-profile', 'bf16_pretrained_k1': 'trained-profile',\n",
    "    'bf16_mixture_k4': 'trained-profile', 'zstd19_dict': 'trained-profile',\n",
    "}\n",
    "# Method class -> (matplotlib colour, marker) used in Figure 5.\n",
    "CLASS_STYLE = {\n",
    "    'general-purpose': ('#1f77b4', 'o'),\n",
    "    'byte-grouped':    ('#ff7f0e', 's'),\n",
    "    'format-aware':    ('#2ca02c', '^'),\n",
    "    'trained-profile': ('#d62728', 'D'),\n",
    "}\n",
    "\n",
    "def _iter_json_records(lines):\n",
    "    \"\"\"Yield one parsed JSON object per non-blank line of `lines` (an iterable of str).\"\"\"\n",
    "    for line in lines:\n",
    "        if line.strip():\n",
    "            yield json.loads(line)\n",
    "\n",
    "def load_jsonl(path):\n",
    "    \"\"\"Yield the records of a JSON-lines file; `.zst` files are decompressed on the fly.\n",
    "\n",
    "    The `.zst` branch streams the decompressed text instead of buffering the whole\n",
    "    output (so there is no 2 GiB output cap). It also reads across every zstd frame,\n",
    "    so a file built by concatenating frames is read in full rather than stopping after\n",
    "    the first frame. Both branches decode UTF-8 explicitly instead of relying on the\n",
    "    platform default. Lines are split only on real newlines: `str.splitlines` also\n",
    "    splits on U+2028 / U+2029 / NEL, which may legally appear unescaped in JSON strings.\n",
    "    \"\"\"\n",
    "    p = pathlib.Path(path)\n",
    "    if p.suffix == '.zst':\n",
    "        with p.open('rb') as fh:\n",
    "            reader = zstandard.ZstdDecompressor().stream_reader(fh, read_across_frames=True)\n",
    "            with io.TextIOWrapper(reader, encoding='utf-8') as text:\n",
    "                yield from _iter_json_records(text)\n",
    "    else:\n",
    "        with p.open(encoding='utf-8') as text:\n",
    "            yield from _iter_json_records(text)\n",
    "\n",
    "def save_fig(fig, stem):\n",
    "    \"\"\"Save `fig` as FIG_DIR/<stem>.pdf and FIG_DIR/<stem>.png (150 dpi), then close it.\"\"\"\n",
    "    for ext in ('pdf', 'png'):\n",
    "        path = FIG_DIR / f'{stem}.{ext}'\n",
    "        fig.savefig(path, bbox_inches='tight', dpi=150 if ext == 'png' else None)\n",
    "        print(f'  wrote {path}')\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Figure 1 - bf16 byte-marginal entropy decomposition\n",
    "\n",
    "bf16 layout: byte 1 (the high byte) holds `[sign(1) + exp_high(7)]`, byte 0 (the low byte) holds `[exp_low(1) + mantissa(7)]`. The atlas measures the two byte-marginal entropies; under Prop. 1 the iid-byte-marginal ceiling is governed by their sum. The figure stacks the two byte entropies per tensor (at most 30 tensors per category, in tensor-name order) and draws a dashed reference line at the raw bf16 cost of 16 bits/value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure 1: stacked byte-1 / byte-0 marginal entropies per bf16 tensor, grouped by category.\n",
    "FIG1_CATEGORY_ORDER = ['embedding','attention','mlp','lm_head','norm','other']\n",
    "FIG1_MAX_PER_CATEGORY = 30  # bars per category, to keep the figure legible\n",
    "\n",
    "rows = [r for r in load_jsonl(ATLAS_BF16) if r.get('dtype') == 'bf16']\n",
    "by_cat = defaultdict(list)\n",
    "for r in rows:\n",
    "    by_cat[r.get('tensor_category', 'other')].append(r)\n",
    "\n",
    "# Known categories come first in a fixed order, then any other category the atlas uses.\n",
    "# (Previously such tensors were dropped from the figure without notice.)\n",
    "cats = [c for c in FIG1_CATEGORY_ORDER if c in by_cat]\n",
    "cats += sorted((c for c in by_cat if c not in FIG1_CATEGORY_ORDER), key=str)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(12, 5))\n",
    "bar_idx = 0; x_labels = []; x_ticks = []; legend_seen = False\n",
    "for cat in cats:\n",
    "    # Only a missing entropy disqualifies a tensor. 0.0 bits (a constant byte) is a\n",
    "    # legitimate value and is drawn. Filtering happens before the cap, so a category\n",
    "    # still shows up to FIG1_MAX_PER_CATEGORY drawable tensors.\n",
    "    drawable = [row for row in sorted(by_cat[cat], key=lambda row: row.get('tensor_name') or '')\n",
    "                if row.get('H_byte1') is not None and row.get('H_byte0') is not None]\n",
    "    start = bar_idx\n",
    "    for r in drawable[:FIG1_MAX_PER_CATEGORY]:\n",
    "        h_byte1 = r['H_byte1']; h_byte0 = r['H_byte0']\n",
    "        label1 = 'byte 1: sign + exp_high' if not legend_seen else None\n",
    "        label0 = 'byte 0: exp_low + mantissa' if not legend_seen else None\n",
    "        ax.bar(bar_idx, h_byte1, color='#1f77b4', label=label1)\n",
    "        ax.bar(bar_idx, h_byte0, bottom=h_byte1, color='#2ca02c', label=label0)\n",
    "        legend_seen = True\n",
    "        bar_idx += 1\n",
    "    if bar_idx > start:\n",
    "        x_ticks.append((start + bar_idx - 1) / 2)  # centre the category label under its bars\n",
    "        x_labels.append(str(cat))\n",
    "        ax.axvline(bar_idx - 0.5, color='#cccccc', lw=0.5, zorder=0)  # category separator\n",
    "\n",
    "ax.axhline(16, color='r', ls='--', lw=0.8, label='raw bf16 (16 bits/value)')\n",
    "ax.set_xticks(x_ticks); ax.set_xticklabels(x_labels)\n",
    "# max(..., 1) avoids a degenerate (-0.5, -0.5) x-range when nothing is drawable.\n",
    "ax.set_xlim(-0.5, max(bar_idx, 1) - 0.5); ax.set_ylim(0, 17)\n",
    "ax.set_ylabel('byte-marginal entropy (bits / value)')\n",
    "ax.set_xlabel('tensor (grouped by category)')\n",
    "ax.set_title('Figure 1 - bf16 byte-marginal entropy decomposition (per tensor, by category)')\n",
    "ax.legend(loc='lower right', fontsize=8, framealpha=0.9)\n",
    "save_fig(fig, 'fig1')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Figure 2 - per-tensor achieved ratio vs R_marginal\n",
    "\n",
    "Top panel (bf16): `bf16_split` is in Prop. 1's coder class, so per-tensor achieved ratio sits on or just below `R_marginal` (= atlas `ceiling_ratio`).\n",
    "\n",
    "Bottom panel (Q4_K): `qb_k4` uses within-superblock joint structure - outside Prop. 1's class - so achieved ratio (full Q4_K tensor stream) sits above the per-tensor nibble-stream marginal ceiling (8 / H_byte). The per-tensor full-stream marginal ceiling is not in the atlas; the §7.2 1.076$\\times$ bound is derived globally."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _square_limits(pts, default_lo, default_hi, pad=0.005):\n",
    "    \"\"\"Shared x/y limits for a y = x panel.\n",
    "\n",
    "    Uses the paper's default window, widened (by `pad`) only when some point\n",
    "    would otherwise be clipped out of view.\n",
    "    \"\"\"\n",
    "    vals = [v for pt in pts for v in pt]\n",
    "    lo = default_lo if not vals or min(vals) >= default_lo else min(vals) - pad\n",
    "    hi = default_hi if not vals or max(vals) <= default_hi else max(vals) + pad\n",
    "    return lo, hi\n",
    "\n",
    "bf16_atlas = {r['tensor_name']: r for r in load_jsonl(ATLAS_BF16) if r.get('dtype') == 'bf16'}\n",
    "atlas_q4k = {r['tensor_name']: r for r in load_jsonl(ATLAS_Q4K)}\n",
    "\n",
    "# A single pass over the (large, compressed) results file collects both methods' ratios.\n",
    "# NOTE(review): ratios are keyed by tensor_name only. If several models share tensor\n",
    "# names, the last record wins; confirm names are unique across the benchmarked models.\n",
    "METHOD_FOR_KIND = {'bench_bf16_perfile': 'bf16_split', 'Q4_K_benchmark': 'qb_k4'}\n",
    "ratios_by_kind = {kind: {} for kind in METHOD_FOR_KIND}\n",
    "for r in load_jsonl(RESULTS):\n",
    "    kind = r.get('_kind')\n",
    "    if kind not in METHOD_FOR_KIND or 'methods' not in r: continue\n",
    "    m = r['methods'].get(METHOD_FOR_KIND[kind])\n",
    "    if m and 'ratio' in m:\n",
    "        ratios_by_kind[kind][r.get('tensor_name','')] = m['ratio']\n",
    "bf16_method = ratios_by_kind['bench_bf16_perfile']\n",
    "qk_method = ratios_by_kind['Q4_K_benchmark']\n",
    "\n",
    "pts_bf = []\n",
    "for k, v in bf16_method.items():\n",
    "    a = bf16_atlas.get(k)\n",
    "    if not a: continue\n",
    "    x = a.get('ceiling_ratio') or a.get('R_marginal') or a.get('ratio_ceiling_byte')\n",
    "    if x and v: pts_bf.append((x, v))\n",
    "\n",
    "pts_q4 = []\n",
    "for k, v in qk_method.items():\n",
    "    a = atlas_q4k.get(k)\n",
    "    if not a: continue\n",
    "    h_byte = a.get('H_byte')\n",
    "    # 8 / H_byte is the nibble-stream marginal ceiling (assumes H_byte is in bits per byte).\n",
    "    if h_byte: pts_q4.append((8.0/h_byte, v))\n",
    "\n",
    "fig, (ax_top, ax_bot) = plt.subplots(2, 1, figsize=(8, 9.5))\n",
    "lo_bf, hi_bf = _square_limits(pts_bf, 1.40, 1.55)\n",
    "if pts_bf:\n",
    "    xs, ys = zip(*pts_bf)\n",
    "    ax_top.scatter(xs, ys, s=14, alpha=0.55, c='#1f77b4', edgecolors='none')\n",
    "ax_top.plot([lo_bf,hi_bf],[lo_bf,hi_bf],'r--',lw=1.0,label='y = x (R_marginal)')\n",
    "ax_top.set_xlim(lo_bf,hi_bf); ax_top.set_ylim(lo_bf,hi_bf)\n",
    "ax_top.set_xlabel('R_marginal - iid byte-marginal ceiling, atlas')\n",
    "ax_top.set_ylabel('bf16_split achieved ratio')\n",
    "ax_top.set_title(f'bf16 (n = {len(pts_bf)}) - achieved ratio vs R_marginal, Prop. 1 ceiling')\n",
    "ax_top.legend(loc='upper left'); ax_top.grid(True, ls=':', alpha=0.3)\n",
    "\n",
    "lo_q4, hi_q4 = _square_limits(pts_q4, 1.02, 1.06)\n",
    "if pts_q4:\n",
    "    xs, ys = zip(*pts_q4)\n",
    "    ax_bot.scatter(xs, ys, s=14, alpha=0.55, c='#1f77b4', edgecolors='none')\n",
    "ax_bot.plot([lo_q4,hi_q4],[lo_q4,hi_q4],'r--',lw=1.0,label='y = x (nibble-stream marginal)')\n",
    "ax_bot.set_xlim(lo_q4,hi_q4); ax_bot.set_ylim(lo_q4,hi_q4)\n",
    "ax_bot.set_xlabel('nibble-stream marginal ceiling: 8 / H_byte (atlas)')\n",
    "ax_bot.set_ylabel('qb_k4 achieved ratio (full Q4_K tensor stream)')\n",
    "ax_bot.set_title(f'Q4_K (n = {len(pts_q4)}) - achieved ratio vs nibble-only ceiling')\n",
    "ax_bot.legend(loc='upper left'); ax_bot.grid(True, ls=':', alpha=0.3)\n",
    "\n",
    "fig.suptitle('Figure 2 - per-tensor achieved ratio vs R_marginal (Prop. 1 ceiling)')\n",
    "fig.tight_layout()\n",
    "save_fig(fig, 'fig2')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
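    "## Figure 3 - Q4_K nibble entropy histogram\n",
    "\n",
    "Per-tensor H(nibble) across the Q4_K atlas. A 16-symbol alphabet has maximum entropy $\\log_2 16 = 4.0$ bits, so the dashed line marks the uniform ceiling; the dotted line marks the median."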
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Default x-window for the paper figure; widened below if any tensor falls outside it.\n",
    "FIG3_XLIM = (3.80, 4.02)\n",
    "\n",
    "nibble_H = [r['H_nibble'] for r in load_jsonl(ATLAS_Q4K) if r.get('H_nibble') is not None]\n",
    "fig, ax = plt.subplots(figsize=(9, 4.5))\n",
    "ax.hist(nibble_H, bins=40, color='#1f77b4', edgecolor='white')\n",
    "# log2(16) = 4 bits: maximum entropy of a 16-symbol alphabet.\n",
    "ax.axvline(4.0, color='r', ls='--', lw=1.2, label='uniform ceiling (4.0 bits)')\n",
    "x_lo, x_hi = FIG3_XLIM\n",
    "if nibble_H:\n",
    "    # np.median([]) would be NaN, so the median marker is only drawn when there is data.\n",
    "    med = float(np.median(nibble_H))\n",
    "    ax.axvline(med, color='k', ls=':', lw=1.0, label=f'median = {med:.3f} bits')\n",
    "    # The fixed window used to hide any tensor below 3.80 bits without notice.\n",
    "    x_lo = min(x_lo, min(nibble_H) - 0.01)\n",
    "    x_hi = max(x_hi, max(nibble_H) + 0.01)\n",
    "ax.set_xlabel('H(nibble) - bits / symbol (alphabet 16)')\n",
    "ax.set_ylabel('number of Q4_K tensors')\n",
    "ax.set_xlim(x_lo, x_hi)\n",
    "ax.set_title(f'Figure 3 - Q4_K nibble entropy distribution (n = {len(nibble_H)} Q4_K-typed tensors)')\n",
    "ax.legend(loc='upper left')\n",
    "ax.grid(True, ls=':', alpha=0.3)\n",
    "save_fig(fig, 'fig3')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Figure 4 - inter-layer Pearson correlation histogram (cross-layer null)\n",
    "\n",
    "Model-level 95% CI band [-0.002, +0.003] from paper §8.2 is overlaid."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure 4: histogram of adjacent same-role layer-pair Pearson correlations.\n",
    "corr = [rec['corr_pearson'] for rec in load_jsonl(ATLAS_LAYER) if 'corr_pearson' in rec]\n",
    "median_corr = float(np.median(corr))\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(9, 4.5))\n",
    "ax.hist(corr, bins=50, color='#2ca02c', edgecolor='white')\n",
    "# Shaded band = model-level 95% CI quoted in paper section 8.2.\n",
    "ax.axvspan(-0.002, 0.003, color='#ffe4b3', alpha=0.6, zorder=0,\n",
    "           label='model-level 95% CI [-0.002, +0.003]')\n",
    "# Reference lines, in legend order: the null (zero correlation), then the observed median.\n",
    "ax.axvline(0, color='r', ls='--', lw=1.0, label='zero (no correlation)')\n",
    "ax.axvline(median_corr, color='k', ls=':', lw=1.0, label=f'median = {median_corr:+.4f}')\n",
    "ax.set(xlabel='Pearson correlation, adjacent same-role layer pair',\n",
    "       ylabel='number of pairs',\n",
    "       title=f'Figure 4 - inter-layer correlation distribution (n = {len(corr)} pairs across two Qwen2.5 source models)')\n",
    "ax.grid(True, ls=':', alpha=0.3)\n",
    "ax.legend(loc='upper right', fontsize=9, framealpha=0.9)\n",
    "save_fig(fig, 'fig4')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Figure 5 - throughput / ratio Pareto across all measured methods\n",
    "\n",
    "One point per (corpus, method): bf16 methods and Q4_K methods are not conflated. Colour/marker shows method class; Pareto-front points have a black outline and are connected by a thin polyline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure 5: one point per (corpus, method) with at least MIN_RECORDS_PER_POINT records.\n",
    "MIN_RECORDS_PER_POINT = 5\n",
    "# Style for methods missing from METHOD_CLASS (a filled marker, so the Pareto outline shows).\n",
    "UNCLASSIFIED_STYLE = ('#7f7f7f', 'v')\n",
    "\n",
    "def _is_pareto_optimal(points):\n",
    "    \"\"\"Return a bool mask over (x, y) `points`.\n",
    "\n",
    "    An entry is True where no other point is >= on both axes and strictly > on at\n",
    "    least one (both axes are maximised). Runs in O(n^2).\n",
    "    \"\"\"\n",
    "    mask = []\n",
    "    for i, (xi, yi) in enumerate(points):\n",
    "        dominated = any(xj >= xi and yj >= yi and (xj > xi or yj > yi)\n",
    "                        for j, (xj, yj) in enumerate(points) if j != i)\n",
    "        mask.append(not dominated)\n",
    "    return mask\n",
    "\n",
    "corpus_for_kind = {'bench_bf16_perfile':'bf16', 'bench_7B':'bf16', 'Q4_K_benchmark':'Q4_K'}\n",
    "agg_raw = defaultdict(list)\n",
    "for r in load_jsonl(RESULTS):\n",
    "    if 'methods' not in r: continue\n",
    "    corpus = corpus_for_kind.get(r.get('_kind',''))\n",
    "    if not corpus: continue\n",
    "    ib = r.get('input_bytes', 0) or 0\n",
    "    for mn, md in r['methods'].items():\n",
    "        if not isinstance(md, dict): continue\n",
    "        ratio = md.get('ratio'); mbps = md.get('decompress_MBps')\n",
    "        if ratio and mbps and ratio>0 and mbps>0:\n",
    "            agg_raw[(corpus, mn)].append((ratio, mbps, ib))\n",
    "\n",
    "points = []\n",
    "for (corpus, mn), vals in agg_raw.items():\n",
    "    if len(vals) < MIN_RECORDS_PER_POINT: continue\n",
    "    # Input-byte-weighted mean ratio; records without input_bytes get weight 1.\n",
    "    # Numerator and denominator use the same weights, so mixed sized/unsized\n",
    "    # records still give a proper weighted mean.\n",
    "    weights = [row[2] or 1 for row in vals]\n",
    "    bw = sum(w * row[0] for w, row in zip(weights, vals)) / sum(weights)\n",
    "    mbps = float(np.median([v[1] for v in vals]))\n",
    "    points.append((corpus, mn, bw, mbps, len(vals)))\n",
    "\n",
    "# Pareto optimality is decided within each corpus. Ratios on different corpora\n",
    "# (bf16 vs Q4_K) are not comparable, so a point from one corpus must not dominate\n",
    "# a point from the other.\n",
    "xy = [(mbps, ratio) for (_,_,ratio,mbps,_) in points]\n",
    "idx_by_corpus = defaultdict(list)\n",
    "for i, pt in enumerate(points):\n",
    "    idx_by_corpus[pt[0]].append(i)\n",
    "mask = [False] * len(points)\n",
    "for idxs in idx_by_corpus.values():\n",
    "    for i, ok in zip(idxs, _is_pareto_optimal([xy[i] for i in idxs])):\n",
    "        mask[i] = ok\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(12, 7))\n",
    "seen = set()\n",
    "for (corpus, mn, ratio, mbps, n), is_p in zip(points, mask):\n",
    "    # Methods missing from METHOD_CLASS are labelled 'unclassified' instead of being\n",
    "    # silently counted as general-purpose.\n",
    "    cls = METHOD_CLASS.get(mn, 'unclassified')\n",
    "    color, marker = CLASS_STYLE.get(cls, UNCLASSIFIED_STYLE)\n",
    "    label = cls if cls not in seen else None\n",
    "    seen.add(cls)\n",
    "    edge = 'black' if is_p else 'none'\n",
    "    lw = 1.4 if is_p else 0\n",
    "    # Marker area grows with log(1 + number of records), floored at 40.\n",
    "    ax.scatter(mbps, ratio, s=max(40, math.log1p(n)*18),\n",
    "               c=color, marker=marker, alpha=0.85,\n",
    "               edgecolors=edge, linewidths=lw, label=label)\n",
    "    ax.annotate(f'{mn} ({corpus})', (mbps, ratio), fontsize=6.5,\n",
    "                alpha=0.75, xytext=(4,4), textcoords='offset points')\n",
    "\n",
    "# One frontier polyline per corpus, distinguished by line style.\n",
    "FRONT_STYLES = ['k-', 'k--', 'k:', 'k-.']\n",
    "for style_i, corpus in enumerate(sorted(idx_by_corpus)):\n",
    "    front = sorted(xy[i] for i in idx_by_corpus[corpus] if mask[i])\n",
    "    if len(front) >= 2:\n",
    "        fx, fy = zip(*front)\n",
    "        ax.plot(fx, fy, FRONT_STYLES[style_i % len(FRONT_STYLES)], lw=0.9, alpha=0.5,\n",
    "                zorder=0, label=f'Pareto front ({corpus})')\n",
    "\n",
    "ax.set_xscale('log')\n",
    "ax.set_xlabel('decompress MB/s (log scale)')\n",
    "ax.set_ylabel('byte-weighted ratio')\n",
    "ax.set_title('Figure 5 - throughput / ratio Pareto across all measured methods (separate points per corpus)')\n",
    "ax.grid(True, ls=':', alpha=0.35, which='both')\n",
    "ax.legend(loc='center right', fontsize=9, framealpha=0.9)\n",
    "save_fig(fig, 'fig5')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
