import json
import numpy as np
from collections import defaultdict

# Load benchmark results: one JSON object per line. Records without a
# "methods" mapping are dropped (presumably metadata or failed entries --
# TODO confirm against the benchmark writer).
BENCHMARK_PATH = "/mnt/data/track_b_v5/results/direction_q/benchmark.jsonl"
with open(BENCHMARK_PATH, encoding="utf-8") as fh:
    # Skip blank lines (e.g. a trailing empty line) instead of letting
    # json.loads raise on an empty string; `with` guarantees the handle closes.
    rows = [json.loads(line) for line in fh if line.strip()]
rows = [r for r in rows if "methods" in r]
print(f"total: {len(rows)}")

# Every method name reported by at least one row, as a sorted list.
methods = sorted({name for r in rows for name in r["methods"]})

# Per-method summary statistics, keyed by method name. A row contributes to a
# method only if that method has a truthy "ratio" and is not marked as failing
# verification (a missing "verified" flag counts as verified). This matches
# the inclusion rule used by the per-model / per-category breakdowns.
#   n, verified           -- rows included (identical, since only verified
#                            rows are aggregated; both keys kept for callers)
#   median_ratio          -- median of the ratios
#   geomean_ratio         -- geometric mean of the ratios
#   median_compress_MBps  -- median "compress_MBps" (missing values count as 0)
#   median_decompress_MBps-- median "decompress_MBps" (missing values count as 0)
#   median_profile_bytes  -- median "external_profile_bytes" (missing -> 0),
#                            truncated to int
agg = {}
for m in methods:
    ratios, cmps, dms, profile_bytes = [], [], [], []
    for r in rows:
        md = r["methods"].get(m)
        # A None ratio would crash np.median and a zero ratio would drive the
        # log-based geomean to 0 (log(0) == -inf), so both are excluded.
        if not md or not md.get("ratio") or not md.get("verified", True):
            continue
        ratios.append(md["ratio"])
        cmps.append(md.get("compress_MBps", 0))
        dms.append(md.get("decompress_MBps", 0))
        profile_bytes.append(md.get("external_profile_bytes", 0))
    if not ratios:
        continue
    agg[m] = {
        "n": len(ratios),
        "verified": len(ratios),
        "median_ratio": float(np.median(ratios)),
        "geomean_ratio": float(np.exp(np.mean(np.log(ratios)))),
        "median_compress_MBps": float(np.median(cmps)),
        "median_decompress_MBps": float(np.median(dms)),
        "median_profile_bytes": int(np.median(profile_bytes)),
    }

# Summary table, best geometric-mean ratio first. sorted() is stable with
# reverse=True, so ties keep their original (alphabetical) order.
sorted_methods = sorted(agg.items(), key=lambda kv: kv[1]["geomean_ratio"], reverse=True)
print(f"\n{'Method':<24} {'n':>4} {'median_ratio':>13} {'geomean':>8} {'comp MB/s':>10} {'dec MB/s':>10}  profile B")
for m, s in sorted_methods:
    # Field widths and separators mirror the header exactly ("  " + 22 == 24)
    # so the columns line up; method names longer than 22 chars will overflow.
    print(
        f"  {m:<22} {s['n']:>4} {s['median_ratio']:>13.4f} {s['geomean_ratio']:>8.4f}"
        f" {s['median_compress_MBps']:>10.1f} {s['median_decompress_MBps']:>10.1f}"
        f"  {s['median_profile_bytes']}"
    )

# Per-model breakdown for a fixed shortlist of methods. A ratio is included
# only if it is truthy and the run did not fail verification -- the same rule
# as the overall aggregate, so unverified runs cannot inflate these numbers.
PER_MODEL_METHODS = [
    "qb_k16", "qb_k8", "qb_k4", "qa",
    "zstd19_raw", "zstd19_bytegrouped", "zstd19_dict",
    "xz9", "brotli_q9", "zstd3_raw", "gzip6",
]
print("\nper model:")
by_model = defaultdict(list)
for r in rows:
    # Rows without a "model" field are grouped under "?" (as the category
    # breakdown does) rather than raising KeyError.
    by_model[r.get("model", "?")].append(r)
for model, group in by_model.items():
    print(f"  {model}: n={len(group)}")
    for m in PER_MODEL_METHODS:
        # `or {}` also guards against a method entry stored as None.
        ratios = [
            md["ratio"]
            for md in (r["methods"].get(m) or {} for r in group)
            if md.get("ratio") and md.get("verified", True)
        ]
        if ratios:
            print(f"    {m}: median={np.median(ratios):.4f} geomean={np.exp(np.mean(np.log(ratios))):.4f}")

# Per-tensor-category breakdown (median ratio only) for a few key methods.
# Categories with too few rows are skipped as too noisy. Ratios use the same
# inclusion rule as the overall aggregate: truthy and not failing verification.
PER_CATEGORY_METHODS = ["qa", "qb_k8", "zstd19_raw", "zstd19_bytegrouped", "xz9"]
MIN_CATEGORY_ROWS = 5
print("\nper category:")
by_cat = defaultdict(list)
for r in rows:
    by_cat[r.get("tensor_category", "?")].append(r)
for cat, group in by_cat.items():
    if len(group) < MIN_CATEGORY_ROWS:
        continue
    print(f"  {cat}: n={len(group)}")
    for m in PER_CATEGORY_METHODS:
        # `or {}` also guards against a method entry stored as None.
        ratios = [
            md["ratio"]
            for md in (r["methods"].get(m) or {} for r in group)
            if md.get("ratio") and md.get("verified", True)
        ]
        if ratios:
            print(f"    {m}: median={np.median(ratios):.4f}")
