"""Method Q.C — General-purpose lossless baselines on Q4_K raw bytes.

For each Q4_K tensor, reassemble its raw on-disk bytes (n_blocks * 144 bytes) from the parsed block fields and run:
  - zstd-3, zstd-19 (raw)
  - zstd-19 bytegrouped (d/dmin, scales, and the low/high nibbles of the 128-byte quant region split into separate streams)
  - zstd-19 with a pre-trained dictionary (only when a dictionary is supplied)
  - xz-9
  - brotli-q9
  - gzip-6

All ratios charge any external profile bytes (e.g. the zstd dictionary) to the compressed size, per the fairness contract.
"""
import argparse
import gzip
import hashlib
import json
import lzma
import os
import struct
import sys
import time
from pathlib import Path

import brotli
import numpy as np
import zstandard as zstd

sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "01_gguf_parser"))
from gguf_inspect import extract_tensor_q4k


# Bytes per block in the bytegrouped intermediate stream: d+dmin (4), scales (12),
# low nibbles (128, one per byte) and high nibbles (128, one per byte).
_BG_ROW_BYTES = 4 + 12 + 128 + 128


def _q4k_raw_bytes(d, dmin, scales6, nibble_blocks) -> bytes:
    """Reassemble the on-disk Q4_K byte stream (144 bytes per block) from the parsed fields.

    Raises:
        ValueError: if a field does not provide exactly its per-block byte width
            (2 / 2 / 12 / 128 bytes) for every block.
    """
    n_blocks = nibble_blocks.shape[0]
    columns = []
    for name, field, width in (
        ("d", d, 2),
        ("dmin", dmin, 2),
        ("scales6", scales6, 12),
        ("nibble_blocks", nibble_blocks, 128),
    ):
        # Reinterpret the field's memory as bytes: identical to what
        # ndarray.tobytes() emits (same dtype byte order), without a Python loop.
        flat = np.ascontiguousarray(field).view(np.uint8).reshape(-1)
        # A wrong width would silently misalign every following block, so fail fast.
        if flat.size != n_blocks * width:
            raise ValueError(
                f"{name}: expected {width} bytes per block for {n_blocks} blocks "
                f"({n_blocks * width} total), got {flat.size}"
            )
        columns.append(flat.reshape(n_blocks, width))
    return np.concatenate(columns, axis=1).tobytes()


def _bytegroup_compress(raw_bytes: bytes) -> bytes:
    """Split 144-byte Q4_K rows into d/dmin | scales | low-nibble | high-nibble streams, then zstd-19."""
    rows = np.frombuffer(raw_bytes, dtype=np.uint8).reshape(-1, 144)
    quants = rows[:, 16:144]
    streams = (
        rows[:, 0:4].tobytes(),           # d and dmin, 2 bytes each
        rows[:, 4:16].tobytes(),          # packed 6-bit scales
        (quants & 0x0F).tobytes(),        # low nibble of every quant byte
        ((quants >> 4) & 0x0F).tobytes(), # high nibble of every quant byte
    )
    return zstd.ZstdCompressor(level=19).compress(b"".join(streams))


def _bytegroup_decompress(comp: bytes) -> bytes:
    """Invert `_bytegroup_compress`, rebuilding the original 144-byte Q4_K rows.

    Raises:
        ValueError: if the decompressed stream is not a whole number of rows.
    """
    blob = zstd.ZstdDecompressor().decompress(comp)
    n, leftover = divmod(len(blob), _BG_ROW_BYTES)
    if leftover:
        raise ValueError(
            f"bytegrouped stream length {len(blob)} is not a multiple of {_BG_ROW_BYTES}"
        )
    buf = np.frombuffer(blob, dtype=np.uint8)
    # Stream boundaries: [d/dmin | scales | low nibbles | high nibbles]
    hdr_end = n * 4
    scales_end = hdr_end + n * 12
    low_end = scales_end + n * 128
    rows = np.empty((n, 144), dtype=np.uint8)
    rows[:, 0:4] = buf[:hdr_end].reshape(n, 4)
    rows[:, 4:16] = buf[hdr_end:scales_end].reshape(n, 12)
    low = buf[scales_end:low_end].reshape(n, 128)
    high = buf[low_end:].reshape(n, 128)
    rows[:, 16:144] = (low & 0x0F) | ((high & 0x0F) << 4)
    return rows.tobytes()


def bench_methods(d, dmin, scales6, nibble_blocks, dict_blob: bytes | None = None):
    """Run all general-purpose baselines on the raw on-disk bytes of a Q4_K tensor.

    The 144-byte block layout ``d (2 B) | dmin (2 B) | scales6 (12 B) | nibbles (128 B)``
    is reassembled from the parsed fields, and every codec is run on that byte string.

    Args:
        d, dmin: one 2-byte value per block (float16 in the self-test).
        scales6: 12 bytes of packed scales per block.
        nibble_blocks: 128 bytes of packed quants per block; ``shape[0]`` is the block count.
        dict_blob: optional raw zstd dictionary. When given, a ``zstd19_dict`` run is
            added and the dictionary size is charged to its compressed size.

    Returns:
        Dict with ``input_bytes``, ``n_blocks``, and one entry per method holding payload,
        profile and total sizes, ratio, compress/decompress throughput (MiB/s), and
        whether the round trip reproduced the input exactly.

    Raises:
        ValueError: if a field does not have the expected per-block byte width.
    """
    n_blocks = nibble_blocks.shape[0]
    raw = _q4k_raw_bytes(d, dmin, scales6, nibble_blocks)
    input_bytes = len(raw)
    input_mib = input_bytes / (1 << 20)
    out = {"input_bytes": input_bytes, "n_blocks": n_blocks}

    def bench(label, comp_fn, decomp_fn, profile_bytes=0):
        # Time compression and decompression separately, then verify the round trip.
        t0 = time.perf_counter()
        comp = comp_fn(raw)
        c_t = time.perf_counter() - t0
        t0 = time.perf_counter()
        restored = decomp_fn(comp)
        d_t = time.perf_counter() - t0
        total = len(comp) + profile_bytes
        out[label] = {
            "compressed_payload_bytes": len(comp),
            "external_profile_bytes": profile_bytes,
            "compressed_bytes": total,
            "ratio": input_bytes / total,
            "compress_MBps": input_mib / max(c_t, 1e-9),
            "decompress_MBps": input_mib / max(d_t, 1e-9),
            "verified": restored == raw,
        }

    bench("zstd3_raw", lambda x: zstd.ZstdCompressor(level=3).compress(x), lambda x: zstd.ZstdDecompressor().decompress(x))
    bench("zstd19_raw", lambda x: zstd.ZstdCompressor(level=19).compress(x), lambda x: zstd.ZstdDecompressor().decompress(x))
    bench("zstd19_bytegrouped", _bytegroup_compress, _bytegroup_decompress)

    if dict_blob is not None:
        zdict = zstd.ZstdCompressionDict(dict_blob)
        cctx_d = zstd.ZstdCompressor(level=19, dict_data=zdict)
        dctx_d = zstd.ZstdDecompressor(dict_data=zdict)
        # The dictionary must ship alongside the payload, so it counts against the ratio.
        bench("zstd19_dict", cctx_d.compress, dctx_d.decompress, profile_bytes=len(dict_blob))

    bench("xz9", lambda x: lzma.compress(x, preset=9), lzma.decompress)
    bench("brotli_q9", lambda x: brotli.compress(x, quality=9), brotli.decompress)
    bench("gzip6", lambda x: gzip.compress(x, compresslevel=6), gzip.decompress)
    return out


def self_test():
    """Round-trip every baseline on a small synthetic Q4_K tensor and fail loudly on mismatch.

    Builds 100 blocks with float16 d/dmin, random scale bytes, and Laplace-shaped
    nibbles centred on 8 (clipped to 0..15). It then checks that each method
    decompresses back to the exact input bytes. The dictionary baseline is not
    exercised because no dictionary is passed.

    Raises:
        AssertionError: if any method fails round-trip verification. The check is
            explicit, so it still runs under ``python -O``.
    """
    rng = np.random.RandomState(0)
    n_blocks = 100
    d = rng.uniform(0.01, 0.5, size=n_blocks).astype(np.float16)
    dmin = rng.uniform(-0.5, 0.5, size=n_blocks).astype(np.float16)
    scales6 = rng.randint(0, 256, size=(n_blocks, 12)).astype(np.uint8)
    nibbles = (8 + rng.laplace(scale=1.5, size=n_blocks * 256).astype(np.int32)).clip(0, 15).astype(np.uint8)
    # 256 nibble values per block become 128 bytes per block after repacking.
    from gguf_inspect import repack_nibbles
    nb = repack_nibbles(nibbles).reshape(n_blocks, 128)
    res = bench_methods(d, dmin, scales6, nb)

    failed = []
    for k, v in res.items():
        if isinstance(v, dict):
            print(f"  {k}: ratio={v['ratio']:.4f}  verified={v['verified']}")
            if not v["verified"]:
                failed.append(k)
    if failed:
        raise AssertionError(f"round-trip verification failed for: {', '.join(failed)}")
    print("OK: general_baselines self-test passed")


def main():
    """CLI entry point: run the self-test, or benchmark one Q4_K tensor and print JSON results.

    Without ``--self-test``, both ``--input`` and ``--tensor`` are required.
    ``--dict`` optionally adds the dictionary-trained zstd baseline.
    """
    ap = argparse.ArgumentParser(description="General-purpose lossless baselines on Q4_K raw bytes.")
    ap.add_argument("--self-test", action="store_true", help="run the synthetic round-trip self-test")
    ap.add_argument("--input", help="GGUF file")
    ap.add_argument("--tensor", help="tensor name")
    ap.add_argument("--dict", help="zstd dictionary file")
    args = ap.parse_args()
    if args.self_test:
        self_test()
        return
    # Fail with a usage message instead of passing None into the GGUF parser.
    if not args.input or not args.tensor:
        ap.error("--input and --tensor are required unless --self-test is given")
    dict_blob = Path(args.dict).read_bytes() if args.dict else None
    d, dmin, scales6, _, nb = extract_tensor_q4k(args.input, args.tensor)
    res = bench_methods(d, dmin, scales6, nb, dict_blob)
    print(json.dumps(res, indent=2))


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
