"""Mixture-CDF bf16_split.

Train K CDFs by clustering training files in (SE9 PMF, M7 PMF) feature space
and fitting one GorillaBf16Predictor per cluster. At compress time, for each
chunk pick the CDF with the lowest cross-entropy (O(alphabet) per candidate,
no extra encode passes), then encode once with that CDF. Per-chunk overhead:
one small cluster tag stored in the container header JSON.

Decode is parameterless except for the external profile + per-chunk tags.

For honest accounting: external profile JSON contains all K CDFs; the
container does NOT embed params (audit fix).
"""
import argparse
import json
import sys
import time
from collections import defaultdict
from pathlib import Path

import numpy as np
from sklearn.cluster import KMeans

sys.path.insert(0, "/mnt/data/scape")
from scape import container
from scape.predictors.gorilla_bf16 import GorillaBf16Predictor
from scape.predictors.base import Stream

_CHUNK_SIZE = 4 * 1024 * 1024 // 2  # chunk length in u16 elements (4 MiB of input bytes)


def file_features(path):
    """Return concatenated (SE9 PMF, M7 PMF) feature vector of length 640."""
    arr = np.fromfile(path, dtype="<u2")
    se9 = ((arr >> 7) & 0x1FF).astype(np.uint16)  # top 9 bits: sign + exponent
    m7 = (arr & 0x7F).astype(np.uint8)            # low 7 bits: mantissa
    h_se = np.bincount(se9.astype(np.int64), minlength=512).astype(np.float64)
    h_m = np.bincount(m7.astype(np.int64), minlength=128).astype(np.float64)
    if h_se.sum() > 0: h_se /= h_se.sum()
    if h_m.sum() > 0: h_m /= h_m.sum()
    return np.concatenate([h_se, h_m])


def cluster_training(corpus_dir, k):
    paths = [p.resolve() for p in sorted(Path(corpus_dir).iterdir()) if p.is_symlink() or p.is_file()]
    print(f"  computing features for {len(paths)} files …", flush=True)
    X = np.stack([file_features(str(p)) for p in paths])
    print(f"  K-means K={k} …", flush=True)
    km = KMeans(n_clusters=k, random_state=0, n_init=4).fit(X)
    return paths, km.labels_


def fit_predictor_on_files(paths):
    """Fit a GorillaBf16Predictor on up to ~1 GiB of u16 data concatenated from `paths`."""
    chunks = []
    total = 0
    cap = 1 << 30  # 1 GiB cap on training bytes per cluster
    for p in paths:
        sz = p.stat().st_size
        if total + sz > cap:
            pairs = max(0, cap - total) // 2
            if pairs > 0:
                chunks.append(np.fromfile(str(p), dtype="<u2", count=pairs))
            break
        chunks.append(np.fromfile(str(p), dtype="<u2"))
        total += sz
    if not chunks:
        return None
    arr = np.concatenate(chunks)
    pred = GorillaBf16Predictor()
    pred.fit(arr)
    return pred


def cmd_train(args):
    out = Path(args.output); out.parent.mkdir(parents=True, exist_ok=True)
    t0 = time.perf_counter()
    paths, labels = cluster_training(args.corpus, args.k)
    cluster_params = []
    counts = []
    for k in range(args.k):
        group = [paths[i] for i in range(len(paths)) if labels[i] == k]
        print(f"  cluster {k}: {len(group)} files", flush=True)
        if not group:
            cluster_params.append(None); counts.append(0); continue
        pred = fit_predictor_on_files(group)
        cluster_params.append(pred.params_dict() if pred else None)
        counts.append(len(group))
    payload = {"k": args.k, "counts": counts, "clusters": cluster_params}
    with open(out, "w") as f:
        json.dump(payload, f)
    sz = out.stat().st_size
    print(f"[mixture.train] wrote {out}  size={sz} bytes  elapsed={time.perf_counter() - t0:.1f}s", flush=True)


def _predictors_from_profile(payload):
    preds = []
    for cp in payload["clusters"]:
        if cp is None:
            preds.append(None)
        else:
            preds.append(GorillaBf16Predictor.from_params(cp))
    return preds


def _pick_cluster(arr_u16, preds):
    """Return index of cluster CDF with lowest expected encoded size, using cross-entropy.

    Cross-entropy H(p_chunk || p_cluster) = -Σ p_chunk[i] * log2(p_cluster[i] + eps)
    Lower = better. We sum cross-entropies for SE9 and M7 streams.
    """
    se9 = ((arr_u16 >> 7) & 0x1FF).astype(np.uint16)
    m7 = (arr_u16 & 0x7F).astype(np.uint8)
    h_se = np.bincount(se9.astype(np.int64), minlength=512).astype(np.float64)
    h_m = np.bincount(m7.astype(np.int64), minlength=128).astype(np.float64)
    if h_se.sum() > 0: h_se /= h_se.sum()
    if h_m.sum() > 0: h_m /= h_m.sum()
    best_i, best_cost = -1, float("inf")
    for i, pred in enumerate(preds):
        if pred is None: continue
        # log2 probabilities from this cluster's SE9 / M7 symbol tables
        log_se = np.log2(pred.cdf_se9 + 1e-30)
        log_m = np.log2(pred.cdf_m7 + 1e-30)
        cost = -(h_se * log_se).sum() - (h_m * log_m).sum()
        if cost < best_cost:
            best_cost = cost; best_i = i
    return best_i


def cmd_compress(args):
    with open(args.pretrained) as f:
        payload = json.load(f)
    preds = _predictors_from_profile(payload)
    K = payload["k"]

    vals = np.fromfile(args.input, dtype="<u2")  # bf16 values as little-endian u16
    n = len(vals)
    all_streams = []
    chunk_tags = []
    for i in range(0, n, _CHUNK_SIZE):
        chunk = vals[i:i + _CHUNK_SIZE]
        ci = i // _CHUNK_SIZE
        idx = _pick_cluster(chunk, preds)
        if idx < 0:
            raise RuntimeError("no clusters available")
        pred = preds[idx]
        streams, _ = pred.encode(chunk)
        chunk_tags.append(idx)
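        # prefix stream names with the chunk index so decode can regroup them per chunk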
        for s in streams:
            all_streams.append((f"c{ci}_{s.name}", s.codec_id, s.payload))
    header = {
        "n": n,
        "chunk_size": _CHUNK_SIZE,
        "n_chunks": len(chunk_tags),
        "chunk_tags": chunk_tags,
        "k": K,
    }
    container.write(args.output, profile_id=4, header_json=header, streams=all_streams)


def cmd_decompress(args):
    with open(args.pretrained) as f:
        payload = json.load(f)
    preds = _predictors_from_profile(payload)

    profile_id, header, raw_streams = container.read(args.input)
    if profile_id != 4:
        raise ValueError(f"expected bf16_split profile (4), got {profile_id}")
    n = header["n"]; chunk_size = header["chunk_size"]; n_chunks = header["n_chunks"]
    chunk_tags = header["chunk_tags"]
    by_chunk = defaultdict(dict)
    for name, codec_id, blob in raw_streams:
        if not name.startswith("c"):
            raise ValueError(f"unexpected stream name {name}")
        idx_str, sname = name[1:].split("_", 1)
        by_chunk[int(idx_str)][sname] = Stream(sname, codec_id, blob)
    parts = []
    for ci in range(n_chunks):
        clen = min(chunk_size, n - ci * chunk_size)
        pred = preds[chunk_tags[ci]]
        parts.append(pred.decode(by_chunk[ci], n=clen))
    np.concatenate(parts).astype("<u2").tofile(args.output)


def main():
    ap = argparse.ArgumentParser()
    sub = ap.add_subparsers(dest="cmd", required=True)
    p = sub.add_parser("train")
    p.add_argument("--corpus", required=True)
    p.add_argument("--output", required=True)
    p.add_argument("--k", type=int, default=8)
    p = sub.add_parser("compress")
    p.add_argument("--input", required=True); p.add_argument("--output", required=True)
    p.add_argument("--pretrained", required=True)
    p = sub.add_parser("decompress")
    p.add_argument("--input", required=True); p.add_argument("--output", required=True)
    p.add_argument("--pretrained", required=True)
    args = ap.parse_args()
    {"train": cmd_train, "compress": cmd_compress, "decompress": cmd_decompress}[args.cmd](args)


if __name__ == "__main__":
    main()
