#!/usr/bin/env python3
"""Benchmark mixture-CDF for given K values."""
import argparse
import hashlib
import json
import os
import shutil
import subprocess
import sys
import time
import uuid
from multiprocessing import Pool, cpu_count
from pathlib import Path

# Prepend the local checkout so its modules take priority over site-packages.
# NOTE(review): done before the numpy import — presumably deliberate ordering; confirm.
sys.path.insert(0, "/mnt/data/scape")
import numpy as np

# Interpreter and script used for the compress/decompress subprocesses.
PY = "/mnt/data/track_b_ai_weights/venv/bin/python3"
MIX_SCRIPT = "/mnt/data/track_b_v3/method/mixture_cdf.py"
# Manifest of test tensors (JSONL, one entry per line).
TEST_JSONL = "/mnt/data/track_b_ai_weights/manifests/test.jsonl"
# Scratch area for per-tensor work dirs; created at import time (module side effect).
WORK_BASE = Path("/mnt/data/track_b_v3/method/work")
WORK_BASE.mkdir(parents=True, exist_ok=True)


def sha256(path):
    """Return the hex SHA-256 digest of the file at *path*, read in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        # Walrus loop: read until the stream is exhausted.
        while chunk := fh.read(1 << 20):
            digest.update(chunk)
    return digest.hexdigest()


# Per-worker state, populated exactly once by the Pool initializer below.
_profile_path = None
_profile_size = None


def init_worker(profile_path):
    """Pool initializer: remember the shared profile path and cache its byte size."""
    global _profile_path, _profile_size
    # Stat once here so workers never re-stat the profile per tensor.
    _profile_size = os.path.getsize(profile_path)
    _profile_path = profile_path


def process_tensor(entry):
    """Round-trip one tensor through MIX_SCRIPT and measure size/time/integrity.

    Parameters:
        entry: manifest row (dict); reads entry["path"] plus the optional
            "tensor_category" / "source_model" fields.

    Returns:
        On success, a metrics dict (sizes, ratio, timings, sha256-verified flag).
        On failure, {"path": ..., "error": ...} — best-effort by design so one
        bad tensor never kills the whole benchmark run.

    Relies on init_worker() having set _profile_path / _profile_size in this
    worker process.
    """
    inp = entry["path"]
    # Unique scratch dir per tensor so parallel workers never collide.
    work_dir = WORK_BASE / uuid.uuid4().hex[:8]
    work_dir.mkdir(parents=True, exist_ok=True)
    try:
        sz = os.path.getsize(inp)
        comp = work_dir / "mix.scape"
        rest = work_dir / "mix.out"
        t0 = time.perf_counter()
        subprocess.run([PY, MIX_SCRIPT, "compress", "--input", inp, "--output", str(comp),
                        "--pretrained", _profile_path], check=True, capture_output=True, timeout=3600)
        c = time.perf_counter() - t0
        t0 = time.perf_counter()
        subprocess.run([PY, MIX_SCRIPT, "decompress", "--input", str(comp), "--output", str(rest),
                        "--pretrained", _profile_path], check=True, capture_output=True, timeout=3600)
        d = time.perf_counter() - t0
        payload = comp.stat().st_size
        # Charge the shared profile against each tensor so the ratio is honest.
        total = payload + _profile_size
        verified = sha256(rest) == sha256(inp)
        return {
            "path": inp,
            "tensor_category": entry.get("tensor_category"),
            "source_model": entry.get("source_model"),
            "input_bytes": sz,
            "compressed_payload_bytes": payload,
            "external_profile_bytes": _profile_size,
            "total_compressed_bytes": total,
            "ratio": sz / total,
            "compress_seconds": c, "decompress_seconds": d,
            "verified": verified,
        }
    except subprocess.CalledProcessError as e:
        # str(e) only reports the exit status; surface the child's captured
        # stderr (tail-truncated) so failures are actually diagnosable.
        stderr = (e.stderr or b"").decode("utf-8", "replace")[-2000:]
        return {"path": inp, "error": f"{e}; stderr: {stderr}"}
    except Exception as e:
        return {"path": inp, "error": str(e)}
    finally:
        # Best-effort cleanup. shutil.rmtree handles nested files/dirs that the
        # old glob("*") + rmdir missed, and ignore_errors never raises — the
        # previous bare `except:` could swallow KeyboardInterrupt/SystemExit.
        shutil.rmtree(work_dir, ignore_errors=True)


def main():
    """Benchmark every raw bf16 tensor in TEST_JSONL against the given profile."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--profile", required=True)
    parser.add_argument("--out-dir", required=True)
    args = parser.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Keep only raw bf16 tensors, smallest first.
    with open(TEST_JSONL) as fh:
        rows = [json.loads(line) for line in fh]
    entries = [e for e in rows if e.get("kind") == "raw" and e.get("dtype") == "bf16"]
    entries.sort(key=lambda e: e["size"])

    print(f"[bench] profile={args.profile} ({os.path.getsize(args.profile)} B), {len(entries)} tensors", flush=True)

    n_workers = min(24, cpu_count())
    out_path = out_dir / "results.jsonl"
    start = time.perf_counter()
    done = 0
    ratios = []
    # Stream results to disk as they arrive; flush per row so progress survives a crash.
    with open(out_path, "w") as sink, Pool(n_workers, initializer=init_worker, initargs=(args.profile,)) as pool:
        for row in pool.imap_unordered(process_tensor, entries, chunksize=1):
            sink.write(json.dumps(row) + "\n")
            sink.flush()
            done += 1
            if "ratio" in row:
                ratios.append(row["ratio"])
            # Progress line every 30 tensors and at completion.
            if done % 30 == 0 or done == len(entries):
                med = float(np.median(ratios)) if ratios else 0.0
                elapsed = time.perf_counter() - start
                print(f"  [{done}/{len(entries)}] med={med:.4f}  ({done/elapsed:.2f} t/s)", flush=True)
    print(f"done in {time.perf_counter() - start:.0f}s", flush=True)


if __name__ == "__main__":  # script entry point
    main()
