"""Method Q.B — Block-aware mixture-CDF coding for Q4_K_M int4 nibbles.

Strategy:
  - Each Q4_K super-block holds 256 nibbles (= 8 sub-blocks of 32 each)
  - Cluster super-blocks into K canonical CDFs (the 16-symbol nibble distribution)
  - Per super-block: 1 cluster ID (log2(K) bits) + range-coded 256 nibbles
  - Profile stored externally: K × 16 ints per row

Train phase:
  - Read super-blocks from training corpus, compute per-block nibble PMF
  - K-means cluster the PMFs (each PMF is a 16-d feature vector)
  - Store K centroid PMFs as the profile

Compress phase:
  - For each super-block: compute its PMF, assign to nearest cluster by KL divergence
  - Range-code 256 nibbles against that cluster's CDF
  - Output: header (n_blocks) + per-block cluster IDs (1 byte each, K<=256) + range-coded payloads
  - Plus the scale/min metadata stream from Q.A's design (zstd-19 byte-grouped)

External profile bytes counted per fairness contract.
"""
import argparse
import hashlib
import json
import os
import struct
import sys
import time
from multiprocessing import Pool, cpu_count
from pathlib import Path

import numpy as np
import zstandard as zstd

sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "01_gguf_parser"))
from gguf_inspect import extract_tensor_q4k, parse_manifest, repack_nibbles, unpack_nibbles

import constriction

_MAGIC = b"QB01"
_PRECISION = 12


def per_block_pmf(nibble_blocks):
    """nibble_blocks: (n_blocks, 128) uint8.  Returns (n_blocks, 16) float64 PMF."""
    n_blocks = nibble_blocks.shape[0]
    symbols = unpack_nibbles(nibble_blocks).reshape(n_blocks, 256)
    # Shift each row's symbols into a disjoint bin range [16*i, 16*i+15] so one
    # bincount call yields all per-row histograms at once.
    offsets = np.arange(n_blocks, dtype=np.int64)[:, None] * 16
    counts = np.bincount((symbols.astype(np.int64) + offsets).ravel(),
                         minlength=16 * n_blocks)
    return counts.reshape(n_blocks, 16).astype(np.float64) / 256.0


def kmeans_simple(features, K, n_iter=30, seed=0):
    """Cluster feature rows into K groups via scikit-learn's KMeans.

    Returns (centroids, labels) where centroids is (K, dim) and labels is an
    int64 array assigning each input row to a cluster.
    """
    from sklearn.cluster import KMeans
    model = KMeans(n_clusters=K, random_state=seed, n_init=4, max_iter=n_iter)
    model.fit(features)
    return model.cluster_centers_, model.labels_.astype(np.int64)


def assign_clusters(features, centroids):
    """Assign each block PMF to the cluster whose PMF minimizes cross-entropy.

    Minimizing H(p, q) = -sum_v p[v] * log2(q[v]) over cluster PMFs q picks
    the distribution that yields the fewest coded bits for a block with
    empirical PMF p — the right criterion for entropy coding, unlike plain
    Euclidean distance on the PMF vectors.

    Args:
        features: (n, 16) float array of per-block PMFs (rows sum to 1).
        centroids: (K, 16) float array of cluster PMFs.

    Returns:
        (n,) int array of cluster indices in [0, K).
    """
    # eps guards log2(0) for centroid entries that are exactly zero.
    log_c = np.log2(centroids + 1e-30)  # (K, 16)
    ce = -features @ log_c.T            # (n, K) cross-entropy matrix
    return ce.argmin(axis=1)


def _smooth_pmf_to_cdf_int(pmf, precision=_PRECISION):
    """Quantize a float PMF to integer counts totalling ~2**precision.

    Every symbol is floored to a count of 1 so no probability is ever exactly
    zero (a zero would make that symbol uncodable). Returns a pair:
    (list of 16 int counts, renormalized float64 PMF derived from them).
    """
    scaled = np.maximum(np.round(pmf * (1 << precision)).astype(np.int64), 1)
    renorm = scaled.astype(np.float64) / scaled.sum()
    return [int(v) for v in scaled.tolist()], renorm


class QBProfile:
    """Trained Q.B profile: K cluster PMFs kept as integer counts (JSON-safe)."""

    def __init__(self, K, cluster_pmfs_int):
        self.K = K
        # Raw integer counts: K lists of 16 ints (the serialization form).
        self.cluster_pmfs_int = cluster_pmfs_int
        # Normalized float PMFs derived from the counts, shape (K, 16).
        rows = []
        for counts in cluster_pmfs_int:
            arr = np.asarray(counts, dtype=np.float64)
            rows.append(arr / arr.sum())
        self.cluster_pmfs = np.array(rows)

    def to_json(self):
        """Serialize to a compact JSON string (no whitespace)."""
        payload = {"K": self.K, "pmfs": self.cluster_pmfs_int}
        return json.dumps(payload, separators=(",", ":"))

    @classmethod
    def from_json(cls, s):
        """Rebuild a profile from a to_json() string."""
        obj = json.loads(s)
        return cls(obj["K"], obj["pmfs"])


def cmd_train(args):
    """Iterate training corpus GGUFs, extract all Q4_K super-blocks' PMFs, K-means.

    Reads up to args.max_blocks super-block PMFs across all corpus files,
    clusters them into args.k centroids, quantizes each centroid to integer
    counts, and writes the resulting QBProfile JSON to args.output.

    Raises SystemExit if the corpus yields no usable Q4_K tensors (instead of
    letting np.concatenate fail with a cryptic ValueError).
    """
    print(f"=== Q.B train (K={args.k}) ===", flush=True)
    all_pmfs = []
    cap = args.max_blocks
    # Running block count — the original re-summed every collected PMF array
    # after each tensor, which is quadratic in the number of tensors.
    total_blocks = 0
    for gguf in args.corpus:
        m = parse_manifest(gguf)
        # Skip tiny tensors: <65536 elements gives too few blocks to matter.
        q4 = [r for r in m if r["dtype"] == "Q4_K" and r["n_elements"] >= 65536]
        print(f"  {gguf}: {len(q4)} Q4_K tensors")
        for r in q4:
            _, _, _, _, nb = extract_tensor_q4k(gguf, r["name"])
            pmf = per_block_pmf(nb)
            all_pmfs.append(pmf)
            total_blocks += pmf.shape[0]
            if total_blocks >= cap:
                print(f"  reached cap of {cap} blocks", flush=True)
                break
        if total_blocks >= cap:
            break
    if not all_pmfs:
        raise SystemExit("no Q4_K tensors found in the training corpus")
    X = np.concatenate(all_pmfs, axis=0)
    if X.shape[0] > cap:
        X = X[:cap]
    print(f"  K-means on {X.shape[0]} block PMFs", flush=True)
    centroids, labels = kmeans_simple(X, args.k)
    # Convert float centroids to int counts; clip so no entry quantizes to 0.
    pmfs_int = []
    for c in centroids:
        ints, _ = _smooth_pmf_to_cdf_int(np.clip(c, 1e-9, None))
        pmfs_int.append(ints)
    profile = QBProfile(args.k, pmfs_int)
    Path(args.output).write_text(profile.to_json())
    print(f"  wrote profile to {args.output} ({os.path.getsize(args.output)} bytes)")


def _bytegrouped_fp16(arr):
    raw = arr.tobytes()
    a = np.frombuffer(raw, dtype=np.uint8).reshape(-1, 2)
    return a[:, 0].tobytes() + a[:, 1].tobytes()


def _ungroup_fp16(blob, n):
    lo = np.frombuffer(blob[:n], dtype=np.uint8)
    hi = np.frombuffer(blob[n:2*n], dtype=np.uint8)
    out = np.empty(2 * n, dtype=np.uint8)
    out[0::2] = lo; out[1::2] = hi
    return out.tobytes()


def compress_qb(d, dmin, scales6, nibble_blocks, profile: QBProfile):
    """Compress one Q4_K tensor's streams into a single Q.B container.

    Args:
        d, dmin: per-super-block scale/min values (length n_blocks); cast to
            little-endian fp16, byte-grouped, then zstd-19 compressed.
        scales6: (n_blocks, 12) uint8 packed 6-bit sub-block scales.
        nibble_blocks: (n_blocks, 128) uint8 — 256 packed 4-bit symbols per
            super-block.
        profile: trained QBProfile providing K cluster PMFs for the ANS coder.

    Returns:
        bytes laid out as: magic/n_blocks/K header, four metadata blob
        lengths, K payload lengths, then d/dmin/scales6/cluster-ID blobs,
        then the K per-cluster ANS payloads.
    """
    n_blocks = nibble_blocks.shape[0]
    nibbles = unpack_nibbles(nibble_blocks).reshape(n_blocks, 256)
    # Per-block PMF
    pmfs = per_block_pmf(nibble_blocks)
    # Assign clusters by min cross-entropy
    labels = assign_clusters(pmfs, profile.cluster_pmfs)
    # Encode nibbles cluster-by-cluster (grouped) to use a single CDF per group
    # Constriction can handle this via ans.encode_reverse with a Categorical(probabilities=cdf)
    payloads = []
    for k in range(profile.K):
        mask = (labels == k)
        n_k = int(mask.sum())
        if n_k == 0:
            # Empty payload still occupies a length slot so the decoder's
            # fixed K-entry length table stays aligned.
            payloads.append(b"")
            continue
        # Boolean-mask indexing preserves file order, so block boundaries
        # within a cluster payload need not be stored: the decoder walks
        # `labels` in the same order and consumes 256 symbols per block.
        syms = nibbles[mask].reshape(-1).astype(np.int32)
        cdf = profile.cluster_pmfs[k]  # NOTE: a normalized PMF, despite the name
        model = constriction.stream.model.Categorical(probabilities=cdf, perfect=False)
        ans = constriction.stream.stack.AnsCoder()
        ans.encode_reverse(syms, model)
        payloads.append(ans.get_compressed().astype(np.uint32).tobytes())
    # Compress metadata: scales (d, dmin, scales6)
    cctx = zstd.ZstdCompressor(level=19)
    d_blob = cctx.compress(_bytegrouped_fp16(d.astype("<f2")))
    dmin_blob = cctx.compress(_bytegrouped_fp16(dmin.astype("<f2")))
    s6_blob = cctx.compress(scales6.tobytes())
    # Cluster ID stream (1 byte per block, K<=256)
    cluster_blob = cctx.compress(labels.astype(np.uint8).tobytes())
    # Container
    parts = []
    parts.append(struct.pack("<4sII", _MAGIC, n_blocks, profile.K))
    parts.append(struct.pack("<IIII", len(d_blob), len(dmin_blob), len(s6_blob), len(cluster_blob)))
    parts.append(struct.pack(f"<{profile.K}I", *[len(p) for p in payloads]))
    parts.append(d_blob); parts.append(dmin_blob); parts.append(s6_blob); parts.append(cluster_blob)
    for p in payloads: parts.append(p)
    return b"".join(parts)


def decompress_qb(blob, profile: QBProfile):
    """Inverse of compress_qb: parse a Q.B container back into its streams.

    Args:
        blob: bytes produced by compress_qb.
        profile: the SAME QBProfile used at compression time — ANS decoding
            with a different profile silently produces garbage symbols.

    Returns:
        (d, dmin, scales6, nibble_blocks) matching the compressor's inputs:
        fp16 arrays, (n_blocks, 12) uint8, and (n_blocks, 128) uint8.
    """
    magic, n_blocks, K = struct.unpack("<4sII", blob[:12])
    assert magic == _MAGIC, f"bad magic"
    pos = 12
    # Fixed-size length tables: 4 metadata blob lengths, then K payload lengths.
    d_len, dmin_len, s6_len, cluster_len = struct.unpack("<IIII", blob[pos:pos+16]); pos += 16
    payload_lens = struct.unpack(f"<{K}I", blob[pos:pos + 4*K]); pos += 4*K
    dctx = zstd.ZstdDecompressor()
    d_raw = dctx.decompress(blob[pos:pos+d_len]); pos += d_len
    dmin_raw = dctx.decompress(blob[pos:pos+dmin_len]); pos += dmin_len
    s6_raw = dctx.decompress(blob[pos:pos+s6_len]); pos += s6_len
    cluster_raw = dctx.decompress(blob[pos:pos+cluster_len]); pos += cluster_len
    # Undo the low/high byte grouping before reinterpreting as fp16.
    d_bytes = _ungroup_fp16(d_raw, n_blocks)
    dmin_bytes = _ungroup_fp16(dmin_raw, n_blocks)
    # .copy() detaches from the read-only frombuffer view of the blob.
    d = np.frombuffer(d_bytes, dtype="<f2").copy()
    dmin = np.frombuffer(dmin_bytes, dtype="<f2").copy()
    scales6 = np.frombuffer(s6_raw, dtype=np.uint8).reshape(n_blocks, 12).copy()
    labels = np.frombuffer(cluster_raw, dtype=np.uint8).astype(np.int64)
    # Decode nibbles per cluster
    decoded_nibbles_by_cluster = {}
    for k in range(profile.K):
        # Block count for cluster k determines how many symbols (n_k * 256)
        # to pull from its ANS payload.
        n_k = int((labels == k).sum())
        payload = blob[pos:pos + payload_lens[k]]
        pos += payload_lens[k]
        if n_k == 0:
            decoded_nibbles_by_cluster[k] = np.array([], dtype=np.int32)
            continue
        cdf = profile.cluster_pmfs[k]  # NOTE: a normalized PMF, despite the name
        model = constriction.stream.model.Categorical(probabilities=cdf, perfect=False)
        comp = np.frombuffer(payload, dtype=np.uint32).copy()
        ans = constriction.stream.stack.AnsCoder(comp)
        decoded_nibbles_by_cluster[k] = ans.decode(model, n_k * 256).astype(np.uint8)
    # Reassemble
    # Walk labels in file order (mirroring the encoder's boolean-mask order)
    # and take the next 256 symbols from that cluster's decoded stream.
    out_nibbles = np.zeros((n_blocks, 256), dtype=np.uint8)
    cursor = {k: 0 for k in range(profile.K)}
    for bi, k in enumerate(labels):
        ki = int(k)
        sub = decoded_nibbles_by_cluster[ki][cursor[ki]:cursor[ki] + 256]
        cursor[ki] += 256
        out_nibbles[bi] = sub
    # Pack back to nibble_blocks
    nibble_flat = out_nibbles.reshape(-1)
    nibble_blocks = repack_nibbles(nibble_flat).reshape(n_blocks, 128)
    return d, dmin, scales6, nibble_blocks


def self_test():
    """Round-trip compress/decompress on synthetic blocks; asserts exact recovery."""
    rng = np.random.RandomState(0)
    n_blocks = 100
    d = rng.uniform(0.01, 0.5, size=n_blocks).astype(np.float16)
    dmin = rng.uniform(-0.5, 0.5, size=n_blocks).astype(np.float16)
    scales6 = rng.randint(0, 256, size=(n_blocks, 12)).astype(np.uint8)
    # Two halves with different Laplace widths, so the blocks genuinely carry
    # a mixture of nibble distributions.
    halves = []
    for width in (1.2, 2.0):
        draw = rng.laplace(scale=width, size=50 * 256).astype(np.int32)
        halves.append((8 + draw).clip(0, 15).astype(np.uint8))
    nibble_blocks = repack_nibbles(np.concatenate(halves)).reshape(n_blocks, 128)
    # Random Dirichlet profile with K=4 — lossless recovery must hold no
    # matter how poorly the profile matches the data.
    pmfs_int = []
    for _ in range(4):
        ints, _ = _smooth_pmf_to_cdf_int(rng.dirichlet(np.ones(16) * 2.0))
        pmfs_int.append(ints)
    profile = QBProfile(4, pmfs_int)
    blob = compress_qb(d, dmin, scales6, nibble_blocks, profile)
    d2, dmin2, s62, nb2 = decompress_qb(blob, profile)
    assert d.tobytes() == d2.tobytes()
    assert dmin.tobytes() == dmin2.tobytes()
    assert np.array_equal(scales6, s62)
    assert np.array_equal(nibble_blocks, nb2)
    print(f"OK: qb_block self-test passed. n_blocks={n_blocks}, K=4, ratio={n_blocks*144/len(blob):.3f}")


def main():
    """CLI entry point: dispatch the train / compress / self-test subcommands."""
    parser = argparse.ArgumentParser()
    sub = parser.add_subparsers(dest="cmd", required=True)

    train_p = sub.add_parser("train")
    train_p.add_argument("--corpus", nargs="+", required=True)
    train_p.add_argument("--k", type=int, default=8)
    train_p.add_argument("--output", required=True)
    train_p.add_argument("--max-blocks", type=int, default=200_000)

    comp_p = sub.add_parser("compress")
    comp_p.add_argument("--input", required=True)
    comp_p.add_argument("--tensor", required=True)
    comp_p.add_argument("--profile", required=True)

    sub.add_parser("self-test")

    args = parser.parse_args()
    if args.cmd == "self-test":
        self_test()
    elif args.cmd == "train":
        cmd_train(args)
    elif args.cmd == "compress":
        profile = QBProfile.from_json(Path(args.profile).read_text())
        d, dmin, scales6, _, nb = extract_tensor_q4k(args.input, args.tensor)
        blob = compress_qb(d, dmin, scales6, nb, profile)
        # Round-trip verification before reporting any numbers.
        _, _, _, nb2 = decompress_qb(blob, profile)
        assert np.array_equal(nb, nb2)
        profile_size = os.path.getsize(args.profile)
        orig = nb.shape[0] * 144  # a Q4_K super-block is 144 bytes on disk
        total = len(blob) + profile_size
        print(f"Q.B K={profile.K} on {args.tensor}: orig={orig} container={len(blob)} profile={profile_size} ratio={orig/total:.4f}")


if __name__ == "__main__":
    main()
