"""Direction Q benchmark — run all methods on test corpus Q4_K tensors.

Test models: Qwen2.5-7B, Llama-3.2-3B, Mistral-7B.
Methods: Q.A (sign+mag), Q.B (mixture K=4/8/16), Q.C (general baselines).
"""
import argparse
import hashlib
import json
import os
import sys
import time
from multiprocessing import Pool, cpu_count
from pathlib import Path

import numpy as np

# Make the sibling Direction Q method directories importable. Prepending the
# reversed list yields the same order as four successive sys.path.insert(0, ...)
# calls: 05_method_general_baseline ends up first, 01_gguf_parser fourth.
sys.path[0:0] = reversed([
    "/mnt/data/track_b_v5/direction_q/01_gguf_parser",
    "/mnt/data/track_b_v5/direction_q/03_method_quant_residual",
    "/mnt/data/track_b_v5/direction_q/04_method_block_aware",
    "/mnt/data/track_b_v5/direction_q/05_method_general_baseline",
])

from gguf_inspect import parse_manifest, extract_tensor_q4k, repack_nibbles, unpack_nibbles
from qa_residual import compress_q4k_tensor_qa, decompress_q4k_tensor_qa
from qb_block import QBProfile, compress_qb, decompress_qb
from general_baselines import bench_methods

_CORPUS_DIR = "/mnt/data/track_b_v5/corpus/gguf"
_DQ_DIR = "/mnt/data/track_b_v5/direction_q"

# Q4_K_M GGUF files of the test models named in the module docstring.
TEST_GGUFS = [
    f"{_CORPUS_DIR}/Qwen2.5-7B/Qwen2.5-7B-Instruct-Q4_K_M.gguf",
    f"{_CORPUS_DIR}/Llama-3.2-3B/Llama-3.2-3B-Instruct-Q4_K_M.gguf",
    f"{_CORPUS_DIR}/Mistral-7B/Mistral-7B-Instruct-v0.3-Q4_K_M.gguf",
]
# Q.B profile file per K (the mixture size listed in the module docstring),
# in ascending K order.
QB_PROFILES = {
    k: f"{_DQ_DIR}/04_method_block_aware/qb_k{k}.profile" for k in (4, 8, 16)
}
# Dictionary file read by init_worker and handed to bench_methods (Q.C).
DICT_PATH = f"{_DQ_DIR}/05_method_general_baseline/zstd_int4.dict"
# Benchmark output: one JSON row per tensor, rewritten from scratch each run.
OUT_JSONL = "/mnt/data/track_b_v5/results/direction_q/benchmark.jsonl"

# Per-process caches populated by init_worker; None until it has run.
_qb_profiles = _dict_blob = None

def init_worker():
    """Load the shared, read-only benchmark inputs into this process's globals.

    Used as the ``Pool`` initializer in ``main``, so each worker reads the
    files once rather than once per tensor. Sets ``_qb_profiles`` to a dict
    mapping K -> ``QBProfile`` parsed from ``QB_PROFILES[K]``, and
    ``_dict_blob`` to the raw bytes of ``DICT_PATH``.
    """
    global _qb_profiles, _dict_blob
    loaded = {}
    for k, profile_path in QB_PROFILES.items():
        loaded[k] = QBProfile.from_json(Path(profile_path).read_text())
    # Publish the profiles only after every one parsed successfully.
    _qb_profiles = loaded
    _dict_blob = Path(DICT_PATH).read_bytes()


def _timed(fn, *args):
    """Call ``fn(*args)`` and return ``(result, elapsed_seconds)``."""
    t0 = time.perf_counter()
    result = fn(*args)
    return result, time.perf_counter() - t0


def _roundtrip_ok(original, restored):
    """Return True iff a decompressed (d, dmin, scales6, nb) tuple matches the original.

    d and dmin are compared as raw bytes. If they are float arrays, this is
    stricter than value equality (e.g. -0.0 vs 0.0 differ). scales6 and nb
    are compared with np.array_equal.
    """
    d, dmin, scales6, nb = original
    d2, dmin2, s62, nb2 = restored
    return (d.tobytes() == d2.tobytes() and dmin.tobytes() == dmin2.tobytes()
            and np.array_equal(scales6, s62) and np.array_equal(nb, nb2))


def _method_record(input_bytes, blob_bytes, profile_bytes, c_t, d_t, verified):
    """Build the per-method result dict (sizes, ratio, throughput in MiB/s, verification)."""
    total = blob_bytes + profile_bytes
    input_mib = input_bytes / (1 << 20)
    return {
        "compressed_bytes": blob_bytes,
        "external_profile_bytes": profile_bytes,
        "total_compressed_bytes": total,
        "ratio": input_bytes / total,
        # Clamp timings so an immeasurably fast run cannot divide by zero.
        "compress_MBps": input_mib / max(c_t, 1e-9),
        "decompress_MBps": input_mib / max(d_t, 1e-9),
        "verified": verified,
    }


def bench_one(args):
    """Benchmark every method on one Q4_K tensor and return a JSON-serializable row.

    Args:
        args: ``(gguf_path, tensor_name, source_model, category)``, packed into a
            single tuple so the function can be used with ``Pool.imap_unordered``.

    Returns:
        If the tensor cannot be extracted: a dict with ``model``, ``tensor_name``,
        ``tensor_category``, the legacy ``path``/``tensor`` keys and ``error``.
        Otherwise: ``model``, ``tensor_name``, ``tensor_category``, ``n_blocks``,
        ``input_bytes`` and ``methods``. ``methods`` holds:

        - ``"qa"``;
        - ``"qb_k{K}"`` for each loaded Q.B profile;
        - every dict-valued entry returned by ``bench_methods``.

        A method that raises is recorded as ``{"error": msg}``. A failure of
        ``bench_methods`` itself is recorded as the string
        ``methods["general_error"]``.
    """
    gguf_path, tensor_name, source_model, category = args
    try:
        d, dmin, scales6, _, nb = extract_tensor_q4k(gguf_path, tensor_name)
    except Exception as e:
        # Use the same identifying keys as a success row so failures can be
        # grouped by model/category; keep "path"/"tensor" for existing readers.
        return {
            "model": source_model, "tensor_name": tensor_name,
            "tensor_category": category,
            "path": gguf_path, "tensor": tensor_name, "error": str(e),
        }

    if _qb_profiles is None or _dict_blob is None:
        # Called outside Pool(initializer=init_worker), e.g. serially while
        # debugging. Load the shared inputs now instead of failing on None.
        init_worker()

    n_blocks = nb.shape[0]
    # 144 bytes per Q4_K super-block in GGUF (d + dmin fp16, 12 packed
    # scale bytes, 128 nibble bytes, per ggml's block_q4_K — confirm).
    input_bytes = n_blocks * 144
    methods = {}
    out = {
        "model": source_model, "tensor_name": tensor_name, "tensor_category": category,
        "n_blocks": n_blocks, "input_bytes": input_bytes, "methods": methods,
    }
    original = (d, dmin, scales6, nb)

    # Q.A: self-contained, so no external profile bytes are charged.
    try:
        qa_blob, c_t = _timed(compress_q4k_tensor_qa, d, dmin, scales6, nb)
        restored, d_t = _timed(decompress_q4k_tensor_qa, qa_blob)
        methods["qa"] = _method_record(
            input_bytes, len(qa_blob), 0, c_t, d_t, _roundtrip_ok(original, restored))
    except Exception as e:
        methods["qa"] = {"error": str(e)}

    # Q.B for each K. The full profile file size is charged to every tensor,
    # so each ratio is a per-tensor worst case rather than amortized.
    for K, profile in _qb_profiles.items():
        try:
            qb_blob, c_t = _timed(compress_qb, d, dmin, scales6, nb, profile)
            restored, d_t = _timed(decompress_qb, qb_blob, profile)
            verified = _roundtrip_ok(original, restored)
            profile_bytes = os.path.getsize(QB_PROFILES[K])
            methods[f"qb_k{K}"] = _method_record(
                input_bytes, len(qb_blob), profile_bytes, c_t, d_t, verified)
        except Exception as e:
            methods[f"qb_k{K}"] = {"error": str(e)}

    # Q.C general baselines: only dict-valued results are kept as methods.
    try:
        res = bench_methods(d, dmin, scales6, nb, _dict_blob)
        for k, v in res.items():
            if isinstance(v, dict):
                methods[k] = v
    except Exception as e:
        methods["general_error"] = str(e)
    return out


def _collect_tasks():
    """Return ``(gguf_path, tensor_name, source_model, tensor_category)`` tuples for bench_one.

    Only Q4_K tensors with at least 65536 elements are included. GGUF files
    that are missing or whose manifest cannot be parsed are reported and
    skipped, so one bad file does not abort the run.
    """
    tasks = []
    for gguf in TEST_GGUFS:
        if not os.path.exists(gguf):
            print(f"missing: {gguf}", flush=True)
            continue
        try:
            manifest = parse_manifest(gguf)
        except Exception as e:
            print(f"manifest error: {gguf}: {e}", flush=True)
            continue
        source_model = Path(gguf).stem
        tasks.extend(
            (gguf, r["name"], source_model, r["tensor_category"])
            for r in manifest
            if r["dtype"] == "Q4_K" and r["n_elements"] >= 65536
        )
    return tasks


def main(argv=None):
    """Benchmark all selected test-corpus Q4_K tensors in parallel and write JSONL.

    Each ``bench_one`` row is written to the output file as soon as it
    completes, so rows appear in completion order, not task order. The file
    is overwritten on each run.

    Args:
        argv: Optional CLI argument list; ``None`` means ``sys.argv[1:]``.
            ``--out`` overrides the output path (default OUT_JSONL).
            ``--workers`` caps the worker processes (default min(24, cpu_count())).
    """
    parser = argparse.ArgumentParser(
        description="Run the Direction Q benchmark on test-corpus Q4_K tensors.")
    parser.add_argument("--out", default=OUT_JSONL,
                        help="output JSONL path, overwritten each run (default: %(default)s)")
    parser.add_argument("--workers", type=int, default=min(24, cpu_count()),
                        help="maximum number of worker processes (default: %(default)s)")
    opts = parser.parse_args(argv)

    out_path = Path(opts.out)
    # Create the directory that actually holds the output file, rather than a
    # separately hard-coded path that can drift from OUT_JSONL.
    out_path.parent.mkdir(parents=True, exist_ok=True)

    tasks = _collect_tasks()
    n_tasks = len(tasks)
    print(f"running on {n_tasks} Q4_K test tensors", flush=True)

    t0 = time.perf_counter()
    done = 0
    # Truncate the output even when there is nothing to run, so a stale file
    # from an earlier run is never mistaken for fresh results.
    with open(out_path, "w", encoding="utf-8") as f:
        if tasks:
            # Extra workers beyond the task count would sit idle but still each
            # load every Q.B profile and the dictionary in init_worker.
            n_workers = max(1, min(opts.workers, n_tasks))
            with Pool(n_workers, initializer=init_worker) as pool:
                for row in pool.imap_unordered(bench_one, tasks, chunksize=1):
                    # Flush per row so completed results survive a crash mid-run.
                    f.write(json.dumps(row) + "\n")
                    f.flush()
                    done += 1
                    if done % 50 == 0 or done == n_tasks:
                        print(f"  [{done}/{n_tasks}] elapsed={time.perf_counter() - t0:.0f}s",
                              flush=True)
    print(f"done in {time.perf_counter() - t0:.0f}s")


if __name__ == "__main__":
    # main() returns None, so the process exits with status 0 on success.
    sys.exit(main())
