"""bf16_split pretrained wrapper.

Train a single global GorillaBf16Predictor on the bf16 training corpus,
save CDFs to JSON, and reuse at compress time (no per-chunk fit).

For honest accounting: do NOT embed pretrained_params in the container header.
The external JSON is the only profile artifact. Decompress requires --pretrained.

(This is the fix the v2 audit recommended in exp 1.)
"""
import argparse
import json
import os
import sys
import time
from pathlib import Path

import numpy as np

sys.path.insert(0, "/mnt/data/scape")
from scape import container
from scape.predictors.gorilla_bf16 import GorillaBf16Predictor

# u16 elements per compression chunk: 4 MiB of raw bytes / 2 bytes per element
# (= 2,097,152 elements). Must match between compress and decompress.
_CHUNK_SIZE = 4 * 1024 * 1024 // 2  # u16 elements per chunk


def _read_u16_concat(corpus_dir, cap_bytes=4 << 30):
    chunks = []
    total = 0
    for p in sorted(Path(corpus_dir).iterdir()):
        real = p.resolve()
        sz = real.stat().st_size
        if total + sz > cap_bytes:
            take = max(0, cap_bytes - total)
            if take >= 2:
                pairs = take // 2
                arr = np.fromfile(str(real), dtype="<u2", count=pairs)
                chunks.append(arr); total += arr.nbytes
            break
        arr = np.fromfile(str(real), dtype="<u2")
        chunks.append(arr); total += arr.nbytes
    if not chunks:
        return np.zeros(0, dtype=np.uint16)
    return np.concatenate(chunks)


def cmd_train(args):
    """Fit one global GorillaBf16Predictor on the corpus and write its
    CDF parameters to the JSON profile at args.output."""
    corpus_dir = Path(args.corpus)
    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    start = time.perf_counter()
    print(f"[bf16_pretrained.train] reading {corpus_dir} …", flush=True)
    data = _read_u16_concat(corpus_dir, cap_bytes=args.cap_bytes)
    print(f"  read {data.nbytes / (1 << 20):.0f} MiB ({data.size:,} u16)", flush=True)

    print("[bf16_pretrained.train] fitting GorillaBf16Predictor …", flush=True)
    predictor = GorillaBf16Predictor()
    predictor.fit(data)
    params = predictor.params_dict()
    elapsed = time.perf_counter() - start

    with open(out_path, "w") as fh:
        json.dump(params, fh)
    size = out_path.stat().st_size
    print(f"[bf16_pretrained.train] wrote {out_path}  size={size} bytes  elapsed={elapsed:.1f}s", flush=True)


def cmd_compress(args):
    """Compress args.input chunk-by-chunk with the external pretrained
    profile (args.pretrained) and write a profile-4 container to args.output."""
    with open(args.pretrained) as fh:
        pred = GorillaBf16Predictor.from_params(json.load(fh))

    data = np.fromfile(args.input, dtype="<u2")
    total = len(data)
    streams_out = []
    markers = []
    for chunk_idx, start in enumerate(range(0, total, _CHUNK_SIZE)):
        piece = data[start:start + _CHUNK_SIZE]
        streams, _ = pred.encode(piece)
        markers.append({"pretrained": True})
        streams_out.extend(
            (f"c{chunk_idx}_{s.name}", s.codec_id, s.payload) for s in streams
        )
    # CRITICAL (audit fix): the header carries no predictor params — the
    # external JSON profile is the only artifact needed at decompress time.
    header = {
        "n": total,
        "chunk_size": _CHUNK_SIZE,
        "n_chunks": len(markers),
        "chunk_params": markers,
    }
    container.write(args.output, profile_id=4, header_json=header, streams=streams_out)


def cmd_decompress(args):
    """Decompress a profile-4 container (args.input) back to raw u16
    (args.output), using the external pretrained profile (args.pretrained).

    Raises:
        ValueError: if the container's profile id is not 4, or if a
            stream name does not follow the ``c<chunk>_<name>`` scheme.
    """
    with open(args.pretrained) as f:
        params = json.load(f)
    pred = GorillaBf16Predictor.from_params(params)

    profile_id, header, raw_streams = container.read(args.input)
    if profile_id != 4:
        raise ValueError(f"expected bf16_split profile (4), got {profile_id}")
    n = header["n"]; chunk_size = header["chunk_size"]; n_chunks = header["n_chunks"]
    from collections import defaultdict
    from scape.predictors.base import Stream
    # Group streams by chunk index: "c<idx>_<name>" -> by_chunk[idx][name].
    by_chunk = defaultdict(dict)
    for name, codec_id, payload in raw_streams:
        if not name.startswith("c"):
            raise ValueError(f"unexpected stream name {name}")
        idx_str, sname = name[1:].split("_", 1)
        by_chunk[int(idx_str)][sname] = Stream(sname, codec_id, payload)
    parts = []
    for idx in range(n_chunks):
        # The last chunk may be shorter than chunk_size.
        clen = min(chunk_size, n - idx * chunk_size)
        parts.append(pred.decode(by_chunk[idx], n=clen))
    # BUG FIX: np.concatenate raises ValueError on an empty list, so a
    # container produced from an empty input (n == 0, zero chunks)
    # previously crashed here instead of writing an empty output file.
    if parts:
        out = np.concatenate(parts)
    else:
        out = np.zeros(0, dtype=np.uint16)
    out.astype("<u2").tofile(args.output)


def main():
    """CLI entry point with three sub-commands: train, compress, decompress."""
    parser = argparse.ArgumentParser()
    commands = parser.add_subparsers(dest="cmd", required=True)

    train = commands.add_parser("train")
    train.add_argument("--corpus", required=True)
    train.add_argument("--output", required=True)
    train.add_argument("--cap-bytes", type=int, default=4 * (1 << 30))

    compress = commands.add_parser("compress")
    compress.add_argument("--input", required=True)
    compress.add_argument("--output", required=True)
    compress.add_argument("--pretrained", required=True)

    decompress = commands.add_parser("decompress")
    decompress.add_argument("--input", required=True)
    decompress.add_argument("--output", required=True)
    decompress.add_argument("--pretrained", required=True)

    args = parser.parse_args()
    dispatch = {
        "train": cmd_train,
        "compress": cmd_compress,
        "decompress": cmd_decompress,
    }
    dispatch[args.cmd](args)


if __name__ == "__main__":
    main()
