#!/usr/bin/env python3
"""Benchmark mixture-CDF bf16_split on 290 bf16 test tensors."""
import hashlib
import json
import os
import shutil
import subprocess
import sys
import tempfile
import time
import uuid
from multiprocessing import Pool, cpu_count
from pathlib import Path

sys.path.insert(0, "/mnt/data/scape")
import numpy as np
from scape import container as _scape_container

# Filesystem layout. Both directories are created as soon as this module is imported.


def _ensure_dir(path):
    """Create *path* (including missing parents) if needed and return it unchanged."""
    path.mkdir(parents=True, exist_ok=True)
    return path


# Where results.jsonl is written, and the parent of every per-tensor scratch dir.
OUT_DIR = _ensure_dir(Path("/mnt/data/track_b_v3/method/03_bench_mixture_k8"))
WORK_BASE = _ensure_dir(Path("/mnt/data/track_b_v3/method/work"))

# External interpreter, codec script, trained profile and input manifest.
PY = "/mnt/data/track_b_ai_weights/venv/bin/python3"
MIX_SCRIPT = "/mnt/data/track_b_v3/method/mixture_cdf.py"
PROFILE = "/mnt/data/track_b_v3/method/trained/bf16_mixture_k8.json"
# Byte size of the profile, read once at import; added to each tensor's total size.
PROFILE_SIZE = os.path.getsize(PROFILE)
TEST_JSONL = "/mnt/data/track_b_ai_weights/manifests/test.jsonl"


def sha256(path):
    """Return the hex SHA-256 digest of the file at *path*.

    The file is streamed in 1 MiB pieces, so large tensors are hashed
    without being loaded into memory all at once.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        # iter(callable, sentinel) keeps calling read() until it returns b"" (EOF).
        for piece in iter(lambda: stream.read(1 << 20), b""):
            digest.update(piece)
    return digest.hexdigest()


def process_tensor(entry):
    """Round-trip one tensor through the mixture codec and measure it.

    Runs ``MIX_SCRIPT compress`` and then ``MIX_SCRIPT decompress`` as
    subprocesses, each with a one-hour timeout. Both run inside a freshly
    created scratch directory under ``WORK_BASE``. The restored file is then
    compared with the input by SHA-256.

    Args:
        entry: One manifest record. ``"path"`` is required. ``"tensor_category"``
            and ``"source_model"`` are copied into the result when present.

    Returns:
        A JSON-serialisable dict. On success it holds:

        - the byte sizes;
        - ``ratio = input_bytes / (payload + PROFILE_SIZE)``;
        - wall-clock timings and throughput;
        - the ``verified`` flag and the container's ``chunk_tags``.

        On any failure it holds ``"path"`` and ``"error"``, plus a
        ``"stderr"`` tail when a subprocess failed or timed out. One bad
        tensor therefore never aborts the whole pool run.
    """
    inp = entry["path"]
    work_dir = None
    try:
        # mkdtemp creates a uniquely named directory atomically, so two workers
        # (or two runs sharing WORK_BASE) can never share or delete each
        # other's scratch files.
        work_dir = Path(tempfile.mkdtemp(prefix="mix_", dir=str(WORK_BASE)))
        sz = os.path.getsize(inp)
        comp = work_dir / "mix.scape"
        rest = work_dir / "mix.out"

        t0 = time.perf_counter()
        subprocess.run([PY, MIX_SCRIPT, "compress", "--input", inp, "--output", str(comp),
                        "--pretrained", PROFILE], check=True, capture_output=True, timeout=3600)
        c = time.perf_counter() - t0

        t0 = time.perf_counter()
        subprocess.run([PY, MIX_SCRIPT, "decompress", "--input", str(comp), "--output", str(rest),
                        "--pretrained", PROFILE], check=True, capture_output=True, timeout=3600)
        d = time.perf_counter() - t0

        payload = comp.stat().st_size
        # PROFILE is needed by both compress and decompress, so its size is
        # charged to every tensor's total.
        total = payload + PROFILE_SIZE
        verified = sha256(rest) == sha256(inp)
        try:
            _, hdr, _ = _scape_container.read(str(comp))
            tags = hdr.get("chunk_tags", [])
        except Exception:
            # Tags are informational only; an unreadable header must not
            # discard an otherwise complete measurement.
            tags = []
        mib = sz / (1 << 20)
        return {
            "path": inp,
            "tensor_category": entry.get("tensor_category"),
            "source_model": entry.get("source_model"),
            "input_bytes": sz,
            "compressed_payload_bytes": payload,
            "external_profile_bytes": PROFILE_SIZE,
            "total_compressed_bytes": total,
            "ratio": sz / total,
            "compress_seconds": c,
            "decompress_seconds": d,
            # Despite the "MB" in the key names, these are MiB/s (2**20 bytes).
            "compress_MBps": mib / c if c > 0 else 0,
            "decompress_MBps": mib / d if d > 0 else 0,
            "verified": verified,
            "chunk_tags": tags,
        }
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as e:
        # str(e) only says which command failed; keep the tail of the
        # captured stderr so the failure can be diagnosed from results.jsonl.
        err = e.stderr or b""
        if isinstance(err, bytes):
            err = err.decode("utf-8", errors="replace")
        return {"path": inp, "error": str(e), "stderr": err[-4000:]}
    except Exception as e:
        return {"path": inp, "error": str(e)}
    finally:
        # Best-effort cleanup. ignore_errors suppresses filesystem errors
        # only, unlike a bare except, which would also swallow
        # KeyboardInterrupt.
        if work_dir is not None:
            shutil.rmtree(work_dir, ignore_errors=True)


def main():
    """Benchmark every raw bf16 tensor listed in TEST_JSONL.

    Selects manifest records with ``kind == "raw"`` and ``dtype == "bf16"``
    and fans them out to :func:`process_tensor` over a process pool of at
    most 24 workers. Each result is streamed as one JSON line to
    ``OUT_DIR / "results.jsonl"`` as soon as it arrives. A progress line is
    printed every 20 tensors and at the end.

    The reported median ratio only includes round-trips whose restored
    output was byte-identical to the input. Errors and verification
    failures are counted separately and summarised at the end.
    """
    entries = []
    with open(TEST_JSONL, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines (e.g. a trailing newline)
            e = json.loads(line)
            if e.get("kind") == "raw" and e.get("dtype") == "bf16":
                entries.append(e)
    # Ascending by size.
    # NOTE(review): this hands the largest tensors out last, which can leave
    # a long single-worker tail; consider reverse=True if wall time matters.
    entries.sort(key=lambda e: e["size"])
    n_total = len(entries)
    print(f"[bench] {n_total} bf16 tensors, profile={PROFILE_SIZE} bytes", flush=True)
    n_workers = min(24, cpu_count())

    out_path = OUT_DIR / "results.jsonl"
    ratios = []        # ratios of verified (lossless) round-trips only
    n_errors = 0       # rows where process_tensor reported an exception
    n_unverified = 0   # rows whose restored output did not match the input
    t0 = time.perf_counter()
    with open(out_path, "w", encoding="utf-8") as f, Pool(n_workers) as pool:
        results = pool.imap_unordered(process_tensor, entries, chunksize=1)
        for done, row in enumerate(results, start=1):
            f.write(json.dumps(row) + "\n")
            f.flush()  # keep partial results on disk if the run is interrupted
            if "error" in row:
                n_errors += 1
            elif row.get("verified"):
                ratios.append(row["ratio"])
            else:
                # A non-lossless round-trip says nothing valid about the
                # compression ratio, so it is kept out of the median.
                n_unverified += 1
            if done % 20 == 0 or done == n_total:
                med = float(np.median(ratios)) if ratios else 0.0
                elapsed = time.perf_counter() - t0
                rate = done / elapsed if elapsed > 0 else 0
                eta = (n_total - done) / rate if rate > 0 else 0
                print(f"  [{done}/{n_total}] med ratio={med:.4f}  ({rate:.2f} t/s, eta {eta:.0f}s)", flush=True)
    print(f"[bench] done in {time.perf_counter() - t0:.0f}s", flush=True)
    print(f"[bench] {len(ratios)} verified, {n_unverified} failed verification, "
          f"{n_errors} errors", flush=True)


if __name__ == "__main__":
    main()
