#!/usr/bin/env python3
"""Benchmark bf16_split vs scape_per_file (fp16_xor) on all 290 bf16 test tensors.

Per tensor, measure:
- scape_per_file (fp16_xor)  — v2's current "SCAPE bf16"
- bf16_split                 — our new bf16-native profile
- zstd-19 raw                — reference baseline (no profile bytes)
- zstd-19 bytegrouped        — reference baseline (atlas measured this)

For every method, record compress / decompress wall-clock time, whether the
round trip was byte-exact, and the compressed size.
"""
import hashlib
import json
import os
import subprocess
import sys
import time
import uuid
from multiprocessing import Pool, cpu_count
from pathlib import Path

import numpy as np
import zstandard as zstd

# Directory that receives results.jsonl. NOTE: it is created as a side effect
# of importing/running this module.
OUT_DIR = Path("/mnt/data/track_b_v3/method/01_bench_perfile")
OUT_DIR.mkdir(parents=True, exist_ok=True)
# Parent of the per-tensor scratch directories that hold the scape CLI's
# compressed and restored files. Also created at import time.
WORK_BASE = Path("/mnt/data/track_b_v3/method/work")
WORK_BASE.mkdir(parents=True, exist_ok=True)

# Path to the scape executable invoked for compress / decompress.
SCAPE = "/mnt/data/scape/bin/scape"
# JSONL manifest of test tensors, one JSON object per line.
TEST_JSONL = "/mnt/data/track_b_ai_weights/manifests/test.jsonl"


def sha256(path):
    """Return the hexadecimal SHA-256 digest of the file at ``path``.

    The file is consumed in 1 MiB reads, so a large tensor is never held in
    memory all at once.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        # The two-argument form of iter() keeps calling read() until it
        # returns the sentinel b"", i.e. end of file.
        for piece in iter(lambda: stream.read(1 << 20), b""):
            digest.update(piece)
    return digest.hexdigest()


def run_scape(inp, profile, work_dir):
    """Round-trip ``inp`` through the scape CLI with ``profile``, timing both legs.

    Args:
        inp: Path of the input file, passed to the CLI as ``--input``.
        profile: Profile name passed as ``--profile``; also used to name the
            scratch files inside ``work_dir``.
        work_dir: Directory (a ``Path``) that receives ``<profile>.scape`` and
            ``<profile>.out``.

    Returns:
        A dict with ``compressed_bytes``, ``compress_seconds``,
        ``decompress_seconds`` and ``verified`` (True when the restored file
        has the same SHA-256 as the input).

    Raises:
        subprocess.CalledProcessError: A CLI invocation exited non-zero.
        subprocess.TimeoutExpired: A CLI invocation ran longer than 3600 s.
    """
    packed_path = work_dir / f"{profile}.scape"
    restored_path = work_dir / f"{profile}.out"

    def _timed(argv):
        # Wall-clock seconds spent in a single CLI invocation.
        started = time.perf_counter()
        subprocess.run(argv, check=True, capture_output=True, timeout=3600)
        return time.perf_counter() - started

    compress_wall = _timed(
        [SCAPE, "compress", "--input", inp, "--output", str(packed_path), "--profile", profile]
    )
    decompress_wall = _timed(
        [SCAPE, "decompress", "--input", str(packed_path), "--output", str(restored_path)]
    )

    result = {
        "compressed_bytes": packed_path.stat().st_size,
        "compress_seconds": compress_wall,
        "decompress_seconds": decompress_wall,
    }
    # Byte-exactness is judged by comparing digests of restored vs. original.
    result["verified"] = sha256(restored_path) == sha256(inp)
    return result


def run_zstd_raw(inp):
    """Measure in-memory zstd level-19 compression of the file's raw bytes.

    The file is read before timing starts, so neither timing includes disk
    reads. Constructing the compressor / decompressor object is inside the
    timed region.

    Returns:
        A dict with ``compressed_bytes``, ``compress_seconds``,
        ``decompress_seconds`` and ``verified`` (True when the decompressed
        bytes equal the original bytes).
    """
    original = Path(inp).read_bytes()

    started = time.perf_counter()
    packed = zstd.ZstdCompressor(level=19).compress(original)
    compress_wall = time.perf_counter() - started

    started = time.perf_counter()
    restored = zstd.ZstdDecompressor().decompress(packed)
    decompress_wall = time.perf_counter() - started

    return dict(
        compressed_bytes=len(packed),
        compress_seconds=compress_wall,
        decompress_seconds=decompress_wall,
        verified=restored == original,
    )


def run_zstd_bytegrouped(inp):
    """Measure zstd level-19 on a byte-grouped copy of a 16-bit-word file.

    The file is read as little-endian uint16 words. The payload handed to
    zstd is every word's low byte followed by every word's high byte.

    Timing is symmetric: ``compress_seconds`` covers grouping + compression
    and ``decompress_seconds`` covers decompression + un-grouping, because
    the byte transform is part of the method in both directions.

    Returns:
        A dict with ``compressed_bytes``, ``compress_seconds``,
        ``decompress_seconds`` and ``verified`` (True when the rebuilt words
        equal the words read from the file).

    Raises:
        ValueError: The file length is odd. ``np.fromfile`` would silently
            drop the trailing byte, so the round trip could be reported as
            verified without being byte-exact.
    """
    n_bytes = os.path.getsize(inp)
    if n_bytes % 2:
        raise ValueError(f"{inp}: {n_bytes} bytes is not a whole number of 16-bit words")

    data = np.fromfile(inp, dtype="<u2")
    n = data.size

    t0 = time.perf_counter()
    byte0 = (data & 0xFF).astype(np.uint8).tobytes()
    byte1 = (data >> 8).astype(np.uint8).tobytes()
    comp = zstd.ZstdCompressor(level=19).compress(byte0 + byte1)
    c = time.perf_counter() - t0

    t0 = time.perf_counter()
    rec = zstd.ZstdDecompressor().decompress(comp)
    # First n bytes are the low bytes, the remaining n are the high bytes.
    b0 = np.frombuffer(rec[:n], dtype=np.uint8)
    b1 = np.frombuffer(rec[n:], dtype=np.uint8)
    rec_arr = (b1.astype(np.uint16) << 8) | b0.astype(np.uint16)
    d = time.perf_counter() - t0

    # Coerce to a plain bool so the record is always JSON-serialisable.
    verified = bool(np.array_equal(rec_arr, data))
    return {"compressed_bytes": len(comp), "compress_seconds": c, "decompress_seconds": d, "verified": verified}


def _run_method(fn, *args):
    """Call ``fn(*args)`` and return its result, or an error record on failure.

    A failure becomes ``{"error": str(exc)}``. When the exception carries
    captured stderr (as ``subprocess.CalledProcessError`` and
    ``subprocess.TimeoutExpired`` do), its tail is stored under ``"stderr"``,
    because ``str(exc)`` for those exceptions does not include it.
    """
    try:
        return fn(*args)
    except Exception as e:  # per-method boundary: one failure must not drop the other methods
        failure = {"error": str(e)}
        captured = getattr(e, "stderr", None)
        if isinstance(captured, bytes):
            captured = captured.decode("utf-8", errors="replace")
        if isinstance(captured, str) and captured.strip():
            # Keep only the tail so one noisy failure cannot bloat results.jsonl.
            failure["stderr"] = captured.strip()[-2000:]
        return failure


def process_tensor(entry):
    """Benchmark every method on one manifest entry and return its result row.

    Args:
        entry: Manifest dict. ``path`` is required; ``tensor_name``,
            ``tensor_category`` and ``source_model`` are copied if present.

    Returns:
        A dict with the entry metadata, ``input_bytes`` and ``methods``, which
        maps each method name to a measurement dict, ``{"error": ...}`` or
        ``{"skipped": ...}``. Measurements also get ``ratio``,
        ``compress_MBps`` and ``decompress_MBps``.

    Raises:
        OSError: The input size cannot be read or the scratch directory
            cannot be created. Failures of individual methods are recorded
            in the row rather than raised.
    """
    inp = entry["path"]
    work_dir = WORK_BASE / uuid.uuid4().hex[:8]
    work_dir.mkdir(parents=True, exist_ok=True)
    try:
        sz = os.path.getsize(inp)
        rec = {
            "path": inp,
            "tensor_name": entry.get("tensor_name"),
            "tensor_category": entry.get("tensor_category"),
            "source_model": entry.get("source_model"),
            "input_bytes": sz,
            "methods": {},
        }
        methods = rec["methods"]

        # scape_per_file uses the fp16_xor profile, which is slow: about 5 min
        # was observed for a 22 MiB file, and the largest bf16 tensor is
        # 272 MiB. It is therefore only run on inputs up to 32 MiB.
        if sz <= 32 * (1 << 20):
            methods["scape_per_file"] = _run_method(run_scape, inp, "fp16_xor", work_dir)
        else:
            methods["scape_per_file"] = {"skipped": "size > 32 MiB"}
        methods["bf16_split"] = _run_method(run_scape, inp, "bf16_split", work_dir)
        # zstd reference baselines
        methods["zstd19_raw"] = _run_method(run_zstd_raw, inp)
        methods["zstd19_bytegrouped"] = _run_method(run_zstd_bytegrouped, inp)

        # Derived metrics, only for methods that produced a measurement.
        mib = sz / (1 << 20)
        for m in methods.values():
            if "compressed_bytes" not in m:
                continue
            # A zero-byte output has no meaningful ratio; omit the key rather
            # than raise ZeroDivisionError and abort the whole run.
            if m["compressed_bytes"] > 0:
                m["ratio"] = sz / m["compressed_bytes"]
            m["compress_MBps"] = mib / m["compress_seconds"] if m["compress_seconds"] > 0 else 0
            m["decompress_MBps"] = mib / m["decompress_seconds"] if m["decompress_seconds"] > 0 else 0
        return rec
    finally:
        # Best-effort cleanup of scratch files. Only filesystem errors are
        # ignored, so KeyboardInterrupt / SystemExit still propagate.
        for leftover in work_dir.glob("*"):
            try:
                leftover.unlink()
            except OSError:
                pass
        try:
            work_dir.rmdir()
        except OSError:
            pass

def main():
    """Benchmark every raw bf16 tensor in the manifest and stream results to disk.

    Reads ``TEST_JSONL`` and keeps entries with ``kind == "raw"`` and
    ``dtype == "bf16"``. They are processed smallest first in a process pool,
    and one JSON line per tensor is written to ``OUT_DIR / "results.jsonl"``
    in completion order. A progress line with running median ratios is
    printed every 10 tensors and after the last one.
    """
    entries = []
    with open(TEST_JSONL, encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline line), which
            # json.loads would otherwise reject.
            if not line.strip():
                continue
            e = json.loads(line)
            if e.get("kind") == "raw" and e.get("dtype") == "bf16":
                entries.append(e)
    total = len(entries)
    print(f"[bench] {total} bf16 raw tensors", flush=True)
    entries.sort(key=lambda e: e["size"])

    # Capped at 24 workers: bf16_split is fast, scape_per_file is
    # single-threaded and CPU heavy.
    n_workers = min(24, cpu_count())
    print(f"[bench] {n_workers} workers", flush=True)

    out_path = OUT_DIR / "results.jsonl"
    t0 = time.perf_counter()
    done = 0
    bf16_ratios = []
    pf_ratios = []
    with open(out_path, "w", encoding="utf-8") as f, Pool(n_workers) as pool:
        for row in pool.imap_unordered(process_tensor, entries, chunksize=1):
            # Flush each row so partial results survive an interrupted run.
            f.write(json.dumps(row) + "\n")
            f.flush()
            done += 1
            m = row.get("methods", {})
            if "ratio" in m.get("bf16_split", {}):
                bf16_ratios.append(m["bf16_split"]["ratio"])
            if "ratio" in m.get("scape_per_file", {}):
                pf_ratios.append(m["scape_per_file"]["ratio"])
            if done % 10 == 0 or done == total:
                med_b = float(np.median(bf16_ratios)) if bf16_ratios else 0.0
                med_p = float(np.median(pf_ratios)) if pf_ratios else 0.0
                elapsed = time.perf_counter() - t0
                rate = done / elapsed if elapsed > 0 else 0
                eta = (total - done) / rate if rate > 0 else 0
                print(f"  [{done}/{total}] med ratio bf16_split={med_b:.4f} per_file={med_p:.4f}  "
                      f"({rate:.2f} t/s, eta {eta:.0f}s)", flush=True)
    print(f"[bench] done in {time.perf_counter() - t0:.0f}s. wrote {out_path}", flush=True)

if __name__ == "__main__":
    main()
