"""Int4 entropy atlas for Q4_K_M GGUF tensors.

For each Q4_K tensor in a corpus, measure:
  - H_nibble   : entropy of unpacked int4 stream (alphabet 16)
  - H_byte     : entropy of packed nibble stream (alphabet 256)
  - H_sign     : entropy of high bit of nibble (alphabet 2)
  - H_magnitude: entropy of low 3 bits of nibble (alphabet 8)
  - H_nibble_given_prev: conditional entropy within-block (31 pairs per 32-elem sub-block)
  - magnitude_concentration: fraction with |nibble - 8| <= 2 (post-centering)
  - scale_byte0_entropy, scale_byte1_entropy: entropy of low/high bytes of fp16 scales
  - adjacent_block_scale_correlation_pearson

Convention for nibble interpretation: GGUF Q4_K nibble values are stored as unsigned
4-bit (0..15) but are dequantized via `scale * (nibble - 8) - min`. The center (8)
is the "zero" — magnitude is |nibble - 8|.

Aggregate output: JSONL with one row per tensor.
"""
import argparse
import json
import os
import sys
import time
from multiprocessing import Pool, cpu_count
from pathlib import Path

import numpy as np

sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "01_gguf_parser"))
from gguf_inspect import parse_manifest, extract_tensor_q4k, unpack_nibbles


def shannon(counts):
    """Return the Shannon entropy, in bits, of a histogram of counts.

    An all-zero histogram yields 0.0 rather than NaN.
    """
    n = counts.sum()
    if n == 0:
        return 0.0
    probs = counts.astype(np.float64) / n
    nonzero = probs[probs > 0]
    return float(np.sum(-nonzero * np.log2(nonzero)))


def measure_tensor(gguf_path: str, name: str, source_model: str, tensor_category: str):
    """Compute the entropy/statistics atlas row for one Q4_K tensor.

    Args:
        gguf_path: path to the GGUF file containing the tensor.
        name: tensor name within the GGUF file.
        source_model: model label recorded in the output row.
        tensor_category: category label recorded in the output row.

    Returns:
        A JSON-serializable dict of metrics, or a dict containing an
        "error" key if extraction or measurement fails.
    """
    try:
        d, dmin, scales6, _, nibble_blocks = extract_tensor_q4k(gguf_path, name)
        n_blocks = nibble_blocks.shape[0]
        # Unpack all nibbles to a single int4 stream (one byte per nibble, value 0..15).
        nibbles = unpack_nibbles(nibble_blocks)
        n_elements = nibbles.size
        # Marginal entropy of the 4-bit symbol stream (max 4 bits).
        nibble_hist = np.bincount(nibbles, minlength=16).astype(np.int64)
        H_nibble = shannon(nibble_hist)
        # Entropy of the packed byte stream — original nibble bytes (max 8 bits).
        byte_stream = nibble_blocks.reshape(-1)
        byte_hist = np.bincount(byte_stream, minlength=256).astype(np.int64)
        H_byte = shannon(byte_hist)
        # Sign/magnitude split (interpret nibble as signed via center 8):
        # sign = high bit of the nibble, i.e. 1 iff nibble >= 8.
        signs = (nibbles >> 3) & 1
        sign_hist = np.bincount(signs, minlength=2).astype(np.int64)
        H_sign = shannon(sign_hist)
        # Magnitude relative to 8 (center), values in {0..8}.
        centered_mag = np.abs(nibbles.astype(np.int16) - 8).astype(np.uint8)
        mag_hist = np.bincount(centered_mag, minlength=9).astype(np.int64)
        H_magnitude = shannon(mag_hist)
        # Concentration: fraction with |nibble-8| <= 2 (near center / small magnitude).
        magnitude_concentration = float((centered_mag <= 2).mean())
        # Conditional entropy H(nibble | previous nibble), estimated from a
        # random sample of adjacent pairs for speed. Pairs that straddle
        # block boundaries are included (cheap approximation of the
        # within-block pairing described in the module docstring).
        sample_size = min(n_elements - 1, 2_000_000)
        rng = np.random.RandomState(0)  # fixed seed => reproducible sampling
        if n_elements > sample_size + 1:
            idx = rng.choice(n_elements - 1, size=sample_size, replace=False)
        else:
            idx = np.arange(n_elements - 1)
        prev = nibbles[idx].astype(np.int64)
        curr = nibbles[idx + 1].astype(np.int64)
        joint = np.bincount(prev * 16 + curr, minlength=256).reshape(16, 16).astype(np.float64)
        # H(curr | prev) = sum_s P(prev=s) * H(curr | prev=s), computed directly
        # in floats. Guard against an empty sample (degenerate 1-element tensor).
        H_cond = 0.0
        total_pairs = joint.sum()
        if total_pairs > 0:
            p_prev = joint.sum(axis=1) / total_pairs
            for s in range(16):
                row_total = joint[s].sum()
                if row_total > 0:  # implies p_prev[s] > 0
                    cond_p = joint[s] / row_total
                    cond_p = cond_p[cond_p > 0]
                    H_cond += p_prev[s] * float(-(cond_p * np.log2(cond_p)).sum())
        # Entropies of the low/high bytes of the fp16 block scales.
        scale_bytes = d.view(np.uint8).reshape(-1, 2)
        sb0 = scale_bytes[:, 0]
        sb1 = scale_bytes[:, 1]
        H_scale_byte0 = shannon(np.bincount(sb0.astype(np.int64), minlength=256))
        H_scale_byte1 = shannon(np.bincount(sb1.astype(np.int64), minlength=256))
        # Adjacent block scale correlation (Pearson on consecutive fp16 scales,
        # viewed as float32). np.corrcoef returns NaN when either series has
        # zero variance (constant scales); report 0.0 instead so the JSONL
        # output stays valid JSON.
        d_f32 = d.astype(np.float32)
        if d_f32.size >= 2:
            adj_corr = float(np.corrcoef(d_f32[:-1], d_f32[1:])[0, 1])
            if not np.isfinite(adj_corr):
                adj_corr = 0.0
        else:
            adj_corr = 0.0
        return {
            "tensor_name": name,
            "model": source_model,
            "tensor_category": tensor_category,
            "n_elements": int(n_elements),
            "n_blocks": int(n_blocks),
            "H_nibble": H_nibble,
            "H_byte": H_byte,
            "H_sign": H_sign,
            "H_magnitude": H_magnitude,
            "H_nibble_given_prev": H_cond,
            "magnitude_concentration": magnitude_concentration,
            "scale_byte0_entropy": H_scale_byte0,
            "scale_byte1_entropy": H_scale_byte1,
            "adjacent_block_scale_correlation_pearson": adj_corr,
        }
    except Exception as e:
        # Best-effort per-tensor: record the failure as a row instead of
        # aborting the whole corpus run.
        return {"tensor_name": name, "model": source_model, "error": str(e)}


def _process(args_tuple):
    """Pool worker shim: unpack one task tuple and measure that tensor."""
    path, tensor_name, model, cat = args_tuple
    return measure_tensor(path, tensor_name, model, cat)


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--models", nargs="+", help="GGUF file paths")
    ap.add_argument("--output", required=True)
    ap.add_argument("--log", default=None)
    ap.add_argument("--workers", type=int, default=24)
    ap.add_argument("--self-test", action="store_true")
    args = ap.parse_args()

    if args.self_test:
        # Synthetic test
        # Build a fake nibble stream with strong magnitude concentration
        rng = np.random.RandomState(0)
        # Center at 8 with Laplace concentration
        nibbles = (8 + rng.laplace(scale=1.5, size=10000).astype(np.int32)).clip(0, 15).astype(np.uint8)
        nb = np.bincount(nibbles, minlength=16)
        H = shannon(nb)
        print(f"synthetic Laplacian nibbles: H_nibble = {H:.3f} (uniform would be 4.0)")
        assert H < 3.8, "fake should be sub-uniform"
        print("OK: atlas_int4 self-test passed")
        return

    rows = []
    if args.log:
        log_f = open(args.log, "w")
    else:
        log_f = sys.stderr
    for gguf_path in args.models:
        if not os.path.exists(gguf_path):
            print(f"missing: {gguf_path}", file=log_f); continue
        source_model = os.path.basename(gguf_path).replace(".gguf", "")
        print(f"== {source_model} ==", file=log_f, flush=True)
        manifest = parse_manifest(gguf_path)
        q4k = [r for r in manifest if r["dtype"] == "Q4_K"]
        # Only sufficiently large tensors (skip tiny norm/bias)
        q4k = [r for r in q4k if r["n_elements"] >= 65536]
        print(f"  {len(q4k)} Q4_K tensors (>= 65K elements)", file=log_f, flush=True)
        tasks = [(gguf_path, r["name"], source_model, r["tensor_category"]) for r in q4k]
        t0 = time.perf_counter()
        with Pool(args.workers) as pool:
            for i, row in enumerate(pool.imap_unordered(_process, tasks, chunksize=1)):
                rows.append(row)
                if (i + 1) % 50 == 0 or i + 1 == len(tasks):
                    print(f"  [{i+1}/{len(tasks)}] {time.perf_counter() - t0:.0f}s", file=log_f, flush=True)
    with open(args.output, "w") as f:
        for r in rows:
            f.write(json.dumps(r) + "\n")
    ok = [r for r in rows if "error" not in r]
    if ok:
        meds = {
            "H_nibble": float(np.median([r["H_nibble"] for r in ok])),
            "H_byte": float(np.median([r["H_byte"] for r in ok])),
            "H_sign": float(np.median([r["H_sign"] for r in ok])),
            "H_magnitude": float(np.median([r["H_magnitude"] for r in ok])),
            "H_nibble_given_prev": float(np.median([r["H_nibble_given_prev"] for r in ok])),
            "magnitude_concentration": float(np.median([r["magnitude_concentration"] for r in ok])),
            "scale_byte0_entropy": float(np.median([r["scale_byte0_entropy"] for r in ok])),
            "scale_byte1_entropy": float(np.median([r["scale_byte1_entropy"] for r in ok])),
            "adjacent_block_scale_correlation_pearson": float(np.median([r["adjacent_block_scale_correlation_pearson"] for r in ok])),
        }
        print(f"\nAtlas medians (n={len(ok)}):")
        for k, v in meds.items():
            print(f"  {k}: {v:.4f}")


if __name__ == "__main__":
    main()
