"""Method Q.A — Sign + magnitude bit-plane decomposition of int4 nibbles.

For each Q4_K tensor:
  1. Read super-blocks: (d, dmin, scales6, nibbles)
  2. Unpack nibbles to int4 stream
  3. Split into sign bit (high bit of nibble) and magnitude (low 3 bits)
  4. Compress sign (bit-packed, 8 per byte) + magnitude (one byte per value) streams independently with zstd-19
  5. Compress the super-block d and dmin fp16 streams separately with zstd-19 (each byte-grouped: low bytes, then high bytes)
  6. Compress scales6 (sub-block scale/min packed bytes) with zstd-19

Container layout:
  magic 'QA01' | n_blocks u32 | d_bytes u32 | dmin_bytes u32 | scales6_bytes u32 |
  sign_bytes u32 | mag_bytes u32 | d_blob | dmin_blob | scales6_blob | sign_blob | mag_blob
"""
import argparse
import hashlib
import json
import os
import struct
import sys
from pathlib import Path

import numpy as np
import zstandard as zstd

sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "01_gguf_parser"))
from gguf_inspect import extract_tensor_q4k, repack_nibbles, unpack_nibbles


# 4-byte magic written as the first header field of every Q.A container by
# compress_q4k_tensor_qa and checked by decompress_q4k_tensor_qa before any
# stream is decoded.
_MAGIC = b"QA01"


def _bytegrouped_fp16(arr):
    """Byte-group a 2-byte-per-element array: all first bytes, then all second bytes.

    The array's raw buffer is viewed as (n, 2) byte pairs and returned as the
    column of first bytes followed by the column of second bytes, so the
    result has the same length as ``arr.tobytes()``. For little-endian fp16
    input the first byte of each pair is the low byte.

    Raises ValueError (from the reshape) if the buffer has an odd byte count.
    """
    pairs = np.frombuffer(arr.tobytes(), dtype=np.uint8).reshape(-1, 2)
    # Transposing puts each byte position on its own row; serialising the
    # contiguous copy emits row 0 (first bytes) followed by row 1 (second bytes).
    return np.ascontiguousarray(pairs.T).tobytes()


def _ungroup_fp16(blob, n):
    """Inverse of _bytegrouped_fp16: re-interleave a byte-grouped stream.

    Args:
        blob: exactly ``2 * n`` bytes laid out as ``n`` first (low) bytes
            followed by ``n`` second (high) bytes.
        n: number of 2-byte elements encoded in ``blob``.

    Returns:
        ``bytes`` of length ``2 * n`` in interleaved order
        (lo0, hi0, lo1, hi1, ...). Note this is the raw byte string, not an
        ndarray; callers reinterpret it, e.g. with ``np.frombuffer(..., "<f2")``.

    Raises:
        ValueError: if ``len(blob) != 2 * n``. Without this check a short blob
            can be silently broadcast and a long blob split at the wrong
            offset, yielding corrupt values instead of an error.
    """
    if len(blob) != 2 * n:
        raise ValueError(
            f"byte-grouped stream has {len(blob)} bytes, expected {2 * n} for n={n}"
        )
    lo = np.frombuffer(blob[:n], dtype=np.uint8)
    hi = np.frombuffer(blob[n:], dtype=np.uint8)
    out = np.empty(2 * n, dtype=np.uint8)
    out[0::2] = lo  # even positions: first byte of each element
    out[1::2] = hi  # odd positions: second byte of each element
    return out.tobytes()


def compress_q4k_tensor_qa(d: np.ndarray, dmin: np.ndarray, scales6: np.ndarray,
                            nibble_blocks: np.ndarray) -> bytes:
    """Compress one Q4_K tensor into a Q.A container.

    Each nibble is split into its high bit ("sign") and low three bits
    ("magnitude"); the two planes, the byte-grouped d / dmin streams and the
    raw scales6 bytes are each compressed separately with zstd level 19.

    Args:
        d: per-super-block scale; cast to little-endian fp16 before storage.
        dmin: per-super-block min; cast to little-endian fp16 before storage.
        scales6: packed sub-block scale/min bytes, 12 bytes per super-block.
        nibble_blocks: packed nibbles, one row per super-block (``n_blocks`` is
            taken from ``shape[0]``); unpacked with ``unpack_nibbles``.

    Returns:
        The container bytes: a ``<4sIIIIII`` header (magic, n_blocks and the
        five compressed lengths) followed by the d, dmin, scales6, sign and
        magnitude blobs in that order.

    Raises:
        ValueError: if any stream's size disagrees with ``n_blocks``. The
            decoder derives every stream length from ``n_blocks``, so such a
            container could not be read back correctly.
    """
    n_blocks = nibble_blocks.shape[0]
    nibbles = unpack_nibbles(nibble_blocks)  # uint8 in [0..15]

    # Bit-plane split: high bit of each nibble vs. its low 3 bits.
    sign_stream = ((nibbles >> 3) & 1).astype(np.uint8)
    mag_stream = (nibbles & 0x07).astype(np.uint8)

    # Sign plane is packed 8 values per byte. The magnitude plane is left at
    # one byte per value (wasteful, but simple) and zstd removes the redundancy.
    sign_packed = np.packbits(sign_stream, bitorder="little").tobytes()
    mag_packed = mag_stream.tobytes()

    # fp16 d / dmin are byte-grouped so low and high bytes compress separately.
    d_grouped = _bytegrouped_fp16(d.astype("<f2"))
    dmin_grouped = _bytegrouped_fp16(dmin.astype("<f2"))
    scales6_raw = scales6.tobytes()

    # Sizes the decoder will assume, all derived from n_blocks:
    # 2 bytes per fp16, 12 scales6 bytes and 256 nibbles per super-block.
    expected_sizes = (
        ("d", len(d_grouped), 2 * n_blocks),
        ("dmin", len(dmin_grouped), 2 * n_blocks),
        ("scales6", len(scales6_raw), 12 * n_blocks),
        ("nibbles", len(mag_packed), 256 * n_blocks),
    )
    for label, actual, wanted in expected_sizes:
        if actual != wanted:
            raise ValueError(
                f"{label} stream has {actual} bytes, expected {wanted} "
                f"for n_blocks={n_blocks}"
            )

    cctx = zstd.ZstdCompressor(level=19)
    blobs = [
        cctx.compress(stream)
        for stream in (d_grouped, dmin_grouped, scales6_raw, sign_packed, mag_packed)
    ]
    header = struct.pack("<4sIIIIII", _MAGIC, n_blocks, *(len(b) for b in blobs))
    return b"".join([header, *blobs])


def decompress_q4k_tensor_qa(blob: bytes):
    """Decode a Q.A container back into its Q4_K components.

    Args:
        blob: container bytes as produced by ``compress_q4k_tensor_qa``. Bytes
            after the last declared stream are ignored.

    Returns:
        Tuple ``(d, dmin, scales6, nibble_blocks)``: ``d`` and ``dmin`` are
        little-endian fp16 arrays of length ``n_blocks``, ``scales6`` is a
        uint8 array of shape ``(n_blocks, 12)`` and ``nibble_blocks`` is the
        output of ``repack_nibbles`` reshaped to ``(n_blocks, 128)``.

    Raises:
        struct.error: if ``blob`` is shorter than the fixed header.
        ValueError: if the magic is wrong, or a decompressed stream's length
            does not match the size implied by ``n_blocks``. Without the
            latter check a wrong-length stream could be de-interleaved or
            broadcast into corrupt values without any error.
    """
    header_fmt = "<4sIIIIII"
    header_size = struct.calcsize(header_fmt)
    (magic, n_blocks,
     d_len, dmin_len, s6_len, sign_len, mag_len) = struct.unpack(header_fmt, blob[:header_size])
    if magic != _MAGIC:
        raise ValueError(f"bad magic: {magic}")

    # The five compressed streams follow the header back to back, in this order.
    dctx = zstd.ZstdDecompressor()
    streams = []
    pos = header_size
    for length in (d_len, dmin_len, s6_len, sign_len, mag_len):
        streams.append(dctx.decompress(blob[pos:pos + length]))
        pos += length
    d_raw, dmin_raw, s6_raw, sign_raw, mag_raw = streams

    n_nibbles = n_blocks * 256
    expected_sizes = (
        ("d", d_raw, 2 * n_blocks),
        ("dmin", dmin_raw, 2 * n_blocks),
        ("scales6", s6_raw, 12 * n_blocks),
        ("sign", sign_raw, (n_nibbles + 7) // 8),  # 8 sign bits per byte
        ("magnitude", mag_raw, n_nibbles),         # one byte per nibble
    )
    for label, raw, wanted in expected_sizes:
        if len(raw) != wanted:
            raise ValueError(
                f"{label} stream has {len(raw)} bytes, expected {wanted} "
                f"for n_blocks={n_blocks}"
            )

    # Undo the byte grouping, then reinterpret as little-endian fp16.
    # .copy() detaches the arrays from the read-only bytes buffers.
    d = np.frombuffer(_ungroup_fp16(d_raw, n_blocks), dtype="<f2").copy()
    dmin = np.frombuffer(_ungroup_fp16(dmin_raw, n_blocks), dtype="<f2").copy()
    scales6 = np.frombuffer(s6_raw, dtype=np.uint8).reshape(n_blocks, 12).copy()

    # Recombine the bit planes: sign bit back to bit 3, magnitude in bits 0..2.
    sign_stream = np.unpackbits(np.frombuffer(sign_raw, dtype=np.uint8), bitorder="little")[:n_nibbles]
    mag_stream = np.frombuffer(mag_raw, dtype=np.uint8)
    nibbles = ((sign_stream & 1) << 3) | (mag_stream & 0x07)
    nibble_blocks = repack_nibbles(nibbles).reshape(n_blocks, 128)
    return d, dmin, scales6, nibble_blocks


def self_test():
    """Round-trip a synthetic 64-super-block tensor and report the ratio.

    Builds seeded random d / dmin / scales6 data plus nibbles concentrated
    around 8, compresses and decompresses them, and checks every component
    comes back byte-identical. Prints a one-line summary on success.

    Raises:
        AssertionError: if any component differs after the round trip. The
            checks are explicit raises rather than ``assert`` statements so
            they still run under ``python -O``.
    """
    rng = np.random.RandomState(42)
    n_blocks = 64
    d = rng.uniform(0.01, 0.5, size=n_blocks).astype(np.float16)
    dmin = rng.uniform(-0.5, 0.5, size=n_blocks).astype(np.float16)
    scales6 = rng.randint(0, 256, size=(n_blocks, 12)).astype(np.uint8)
    # Random nibbles with concentration around the mid value 8.
    nibbles = (8 + rng.laplace(scale=1.5, size=n_blocks * 256).astype(np.int32)).clip(0, 15).astype(np.uint8)
    nibble_blocks = repack_nibbles(nibbles).reshape(n_blocks, 128)

    blob = compress_q4k_tensor_qa(d, dmin, scales6, nibble_blocks)
    d2, dmin2, s62, nb2 = decompress_q4k_tensor_qa(blob)

    # d and dmin are compared as raw bytes so the fp16 bit patterns must match exactly.
    checks = (
        ("scales6 mismatch", np.array_equal(scales6, s62)),
        ("nibble_blocks mismatch", np.array_equal(nibble_blocks, nb2)),
        ("d mismatch", d.tobytes() == d2.tobytes()),
        ("dmin mismatch", dmin.tobytes() == dmin2.tobytes()),
    )
    for message, ok in checks:
        if not ok:
            raise AssertionError(message)

    orig_bytes = n_blocks * 144  # baseline: 144 bytes per super-block
    comp_bytes = len(blob)
    print(f"OK: qa_residual self-test passed. n_blocks={n_blocks}  orig={orig_bytes}B  compressed={comp_bytes}B  ratio={orig_bytes/comp_bytes:.3f}")


def main():
    """Command-line entry point.

    ``--self-test`` runs the synthetic round-trip test. Otherwise both
    ``--input`` (GGUF file) and ``--tensor`` (tensor name) are required: the
    tensor is extracted, compressed, its ratio printed, and the result is
    decompressed again to verify a byte-exact round trip.

    Exits with a usage error (status 2) when neither mode is fully specified,
    instead of silently doing nothing.

    Raises:
        AssertionError: if the decompressed tensor differs from the input.
            Raised explicitly so the check also runs under ``python -O``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--self-test", action="store_true")
    ap.add_argument("--input", help="GGUF input")
    ap.add_argument("--tensor", help="tensor name to bench")
    args = ap.parse_args()

    if args.self_test:
        self_test()
        return
    if not (args.input and args.tensor):
        ap.error("nothing to do: pass --self-test, or both --input and --tensor")

    d, dmin, scales6, _, nibble_blocks = extract_tensor_q4k(args.input, args.tensor)
    orig = nibble_blocks.shape[0] * 144  # baseline: 144 bytes per super-block
    blob = compress_q4k_tensor_qa(d, dmin, scales6, nibble_blocks)
    print(f"Q.A on {args.tensor}: orig={orig}, comp={len(blob)}, ratio={orig/len(blob):.4f}")

    # Verify roundtrip
    d2, dmin2, s62, nb2 = decompress_q4k_tensor_qa(blob)
    if not (d.tobytes() == d2.tobytes() and dmin.tobytes() == dmin2.tobytes()):
        raise AssertionError("roundtrip mismatch in d/dmin")
    if not (np.array_equal(scales6, s62) and np.array_equal(nibble_blocks, nb2)):
        raise AssertionError("roundtrip mismatch in scales6/nibble_blocks")
    print("  roundtrip OK")


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
