"""GGUF parser + manifest extractor.

For Q4_K_M quantization, GGUF stores 256-element super-blocks. Each super-block contains:
  - 1 fp16 super-block scale (`d`, 2 bytes)
  - 1 fp16 super-block min (`dmin`, 2 bytes)
  - 12 bytes of packed 6-bit sub-block scales AND mins (8 sub-blocks of 32
    elements, 6-bit scale + 6-bit min each = 96 bits = 12 bytes, shared region)
  - 128 bytes of 4-bit nibbles (256 elements × 4 bits)

Total per super-block: 2 + 2 + 12 + 128 = 144 bytes = 4.5 bits/element.

Reference: https://github.com/ggerganov/llama.cpp/blob/master/ggml/src/ggml-quants.h `block_q4_K`.

This parser:
 - Reads GGUF header + tensor metadata via the gguf python module.
 - Computes byte offsets to each tensor.
 - For Q4_K (== Q4_K_M when configured), extracts the raw 144-byte super-blocks.
 - Splits each super-block into (d, dmin, packed scales/mins, nibbles).

The nibble stream is the 4-bit indices (16 alphabet symbols). To "unpack" them
as a uint8 array, we split each byte into low/high nibbles.
"""
from __future__ import annotations

import argparse
import json
import os
import struct
import sys
from pathlib import Path
from typing import Iterator

import numpy as np

try:
    import gguf
    from gguf import GGUFReader, GGMLQuantizationType
except ImportError:
    print("ERROR: pip install gguf", file=sys.stderr)
    raise

# Q4_K super-block layout (llama.cpp `block_q4_K`):
#   bytes   0-1  : d    (fp16 LE) super-block scale
#   bytes   2-3  : dmin (fp16 LE) super-block min
#   bytes   4-15 : 12 bytes of packed 6-bit sub-block scales + mins (8 sub-blocks)
#   bytes  16-143: 128 bytes of 4-bit nibbles (256 elements, 2 per byte)
_Q4K_BLOCK_BYTES = 144          # bytes per super-block of 256 elements
_Q4K_BLOCK_ELEMENTS = 256       # elements per super-block
_Q4K_SUBBLOCK_ELEMENTS = 32     # elements per sub-block (8 sub-blocks per super-block)
_Q4K_NIBBLE_BYTES_PER_BLOCK = 128
# Header precedes the nibbles: d(2) + dmin(2) + packed scales/mins(12) = 16,
# so nibbles start at byte 16 (the previous value 28 contradicted the actual
# layout used by extract_tensor_q4k, which slices nibbles at [16:144]).
_Q4K_HEADER_BYTES = 2 + 2 + 12


def tensor_category(name: str) -> str:
    """Map a GGUF tensor name to a coarse category.

    Returns one of: "embedding", "lm_head", "attention", "mlp", "norm", "other".
    """
    n = name.lower()
    if "token_embd" in n or "embed_tokens" in n:
        return "embedding"
    # Anchored prefix match: a plain substring test for "output." would also
    # capture per-layer attention outputs like "blk.0.attn_output.weight",
    # misclassifying them as the LM head.
    if n.startswith("output.") or "lm_head" in n:
        return "lm_head"
    if "attn" in n or "wq" in n or "wk" in n or "wv" in n or "wo" in n:
        return "attention"
    if "ffn" in n or "mlp" in n or "feed" in n:
        return "mlp"
    if "norm" in n:
        return "norm"
    return "other"


def parse_manifest(gguf_path: str) -> list[dict]:
    """Walk all tensors in the GGUF and return one manifest row per tensor.

    Each row records name, shape, quantization type, element count, byte
    offset/size and, for Q4_K tensors only, the split between nibble bytes
    and scale/header bytes (other dtypes get None in those fields).
    """
    reader = GGUFReader(gguf_path, "r")
    manifest = []
    for tensor in reader.tensors:
        dims = [int(d) for d in tensor.shape]
        elem_count = int(np.prod(dims)) if dims else 0
        dtype_name = GGMLQuantizationType(int(tensor.tensor_type)).name

        blocks = block_elems = nibble_bytes = scale_bytes = None
        if dtype_name == "Q4_K":
            block_elems = _Q4K_BLOCK_ELEMENTS
            blocks = elem_count // _Q4K_BLOCK_ELEMENTS
            nibble_bytes = blocks * _Q4K_NIBBLE_BYTES_PER_BLOCK
            # Everything that isn't nibbles: d, dmin and the packed scales/mins.
            scale_bytes = blocks * (_Q4K_BLOCK_BYTES - _Q4K_NIBBLE_BYTES_PER_BLOCK)

        manifest.append({
            "name": tensor.name,
            "shape": dims,
            "dtype": dtype_name,
            "n_elements": elem_count,
            "n_blocks": blocks,
            "block_size_elements": block_elems,
            # NOTE(review): assumes tensor.data_offset is an absolute file
            # offset (the extraction path seeks to it directly) — confirm
            # against the installed gguf module version.
            "data_offset": int(tensor.data_offset),
            "raw_size_bytes": int(tensor.n_bytes),
            "ints_size_bytes": nibble_bytes,
            "scales_size_bytes": scale_bytes,
            "tensor_category": tensor_category(tensor.name),
        })
    return manifest


def extract_tensor_q4k(gguf_path: str, name: str) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Extract the raw super-block fields of a Q4_K tensor.

    Returns (d, dmin, scales6, mins6, nibbles):
      d        : (n_blocks,) fp16      super-block scale
      dmin     : (n_blocks,) fp16      super-block min
      scales6  : (n_blocks, 12) uint8  packed 6-bit sub-block scales+mins
      mins6    : (n_blocks, 12) uint8  the SAME packed array (not split here)
      nibbles  : (n_blocks, 128) uint8 packed 4-bit indices (2 per byte)

    Byte layout per 144-byte super-block (llama.cpp `block_q4_K`):
      0-1 d (fp16 LE) | 2-3 dmin (fp16 LE) | 4-15 packed scales/mins | 16-143 nibbles
    Within each nibble byte: low nibble first, then high nibble.

    Raises:
        KeyError: tensor name not present in the file.
        ValueError: tensor is not Q4_K.
        RuntimeError: size mismatch or short read.
    """
    reader = GGUFReader(gguf_path, "r")
    tensor = next((t for t in reader.tensors if t.name == name), None)
    if tensor is None:
        raise KeyError(f"tensor {name!r} not found in {gguf_path}")
    qtype = GGMLQuantizationType(int(tensor.tensor_type)).name
    if qtype != "Q4_K":
        raise ValueError(f"tensor {name} is {qtype}, expected Q4_K")

    n_elements = int(np.prod([int(s) for s in tensor.shape]))
    n_blocks = n_elements // _Q4K_BLOCK_ELEMENTS
    expected_bytes = n_blocks * _Q4K_BLOCK_BYTES
    if int(tensor.n_bytes) != expected_bytes:
        raise RuntimeError(f"tensor {name}: n_bytes={tensor.n_bytes} != n_blocks*{_Q4K_BLOCK_BYTES}={expected_bytes}")

    with open(gguf_path, "rb") as fh:
        fh.seek(int(tensor.data_offset))
        raw = fh.read(expected_bytes)
    if len(raw) != expected_bytes:
        raise RuntimeError(f"short read for {name}: got {len(raw)}, expected {expected_bytes}")

    # One row per 144-byte super-block, then slice out the fields.
    blocks = np.frombuffer(raw, dtype=np.uint8).reshape(n_blocks, _Q4K_BLOCK_BYTES)
    # fp16 decode: copy two little-endian bytes, reinterpret as uint16 -> float16.
    d = blocks[:, 0:2].copy().view(np.uint16).view(np.float16).reshape(-1)
    dmin = blocks[:, 2:4].copy().view(np.uint16).view(np.float16).reshape(-1)
    scales_mins = blocks[:, 4:16].copy()   # 6-bit scale + 6-bit min for 8 sub-blocks, packed
    nibbles = blocks[:, 16:144].copy()     # raw bytes; each byte holds 2 nibbles
    # The 6-bit packing is intricate, so we return it raw — and twice, once in
    # the scales6 slot and once in the mins6 slot (callers split it themselves).
    return d, dmin, scales_mins, scales_mins, nibbles


def unpack_nibbles(nibble_bytes: np.ndarray) -> np.ndarray:
    """Expand N packed bytes into a flat uint8 array of 2N nibbles in [0..15].

    The low nibble of each byte comes first, then the high nibble — the
    ordering llama.cpp uses for Q4 data.
    """
    data = nibble_bytes.reshape(-1).astype(np.uint8)
    # Column 0 = low nibbles, column 1 = high nibbles; flattening row-major
    # interleaves them as low, high, low, high, ...
    pairs = np.stack((data & 0x0F, data >> 4), axis=1)
    return pairs.reshape(-1)


def repack_nibbles(nibbles_int4: np.ndarray) -> np.ndarray:
    """Inverse of unpack_nibbles: pack nibble values in [0..15] into bytes.

    The input is flattened first (mirroring unpack_nibbles, which flattens
    before unpacking), so any shape is accepted as long as the total element
    count is even. Element 2i becomes the low nibble of output byte i and
    element 2i+1 the high nibble. Values are masked to 4 bits.

    Raises:
        ValueError: if the total nibble count is odd.
    """
    # Flatten before slicing: on a 2-D array, [0::2] would select every other
    # ROW rather than every other nibble, silently producing wrong output.
    flat = nibbles_int4.reshape(-1)
    if flat.size % 2 != 0:
        raise ValueError("nibble count must be even")
    low = flat[0::2].astype(np.uint8) & 0x0F
    high = (flat[1::2].astype(np.uint8) & 0x0F) << 4
    return (low | high).astype(np.uint8)


def cmd_inspect(args):
    """CLI handler: write a JSONL manifest and print per-dtype/category counts."""
    rows = parse_manifest(args.input)
    if args.output:
        with open(args.output, "w") as fh:
            fh.writelines(json.dumps(row) + "\n" for row in rows)
    if args.show:
        for row in rows[:args.show]:
            print(json.dumps(row, indent=2))
    # Tally dtype and category frequencies in one pass
    # (dict insertion order = first occurrence, same as two separate loops).
    types: dict = {}
    cat: dict = {}
    for row in rows:
        types[row["dtype"]] = types.get(row["dtype"], 0) + 1
        cat[row["tensor_category"]] = cat.get(row["tensor_category"], 0) + 1
    print(f"\n{args.input}: {len(rows)} tensors, dtypes={types}, categories={cat}")


def cmd_self_test(args):
    """CLI handler: sanity-check nibble pack/unpack and, optionally, Q4_K extraction.

    Always verifies that repack_nibbles/unpack_nibbles are exact inverses on a
    random 256-nibble block. If ``args.input`` names an existing GGUF file that
    contains at least one Q4_K tensor, also extracts the first such tensor and
    checks that the parsed pieces reassemble byte-for-byte into the raw
    on-disk super-blocks. Exits the process with status 0 on success; an
    AssertionError propagates on failure.
    """
    # Tests that pack/unpack roundtrip is correct.
    rng = np.random.RandomState(42)  # fixed seed: deterministic test data
    nibbles = rng.randint(0, 16, size=256).astype(np.uint8)
    packed = repack_nibbles(nibbles)
    assert packed.size == 128, f"expected 128 bytes, got {packed.size}"
    unpacked = unpack_nibbles(packed)
    assert np.array_equal(unpacked, nibbles), "roundtrip failure"
    # If a GGUF file path is given, also try extracting from it.
    if args.input and os.path.exists(args.input):
        rows = parse_manifest(args.input)
        q4k_rows = [r for r in rows if r["dtype"] == "Q4_K"]
        if q4k_rows:
            tname = q4k_rows[0]["name"]
            d, dmin, scales6, mins6, nibbles = extract_tensor_q4k(args.input, tname)
            # Verify byte content matches what's at the offset in the file
            with open(args.input, "rb") as f:
                f.seek(q4k_rows[0]["data_offset"])
                raw_disk = f.read(q4k_rows[0]["raw_size_bytes"])
            n_blocks = q4k_rows[0]["n_blocks"]
            assert d.shape == (n_blocks,), f"d shape {d.shape}"
            assert nibbles.shape == (n_blocks, 128), f"nibbles shape {nibbles.shape}"
            # Reconstruct the raw bytes from our parsed pieces.
            # fp16 -> little-endian byte pairs via a uint8 view; per-block
            # layout is d | dmin | packed scales/mins | nibbles (144 bytes).
            reconstructed = np.zeros((n_blocks, 144), dtype=np.uint8)
            reconstructed[:, 0:2] = d.view(np.uint8).reshape(-1, 2)
            reconstructed[:, 2:4] = dmin.view(np.uint8).reshape(-1, 2)
            # scales6 and mins6 are the same packed array, so writing scales6
            # alone restores bytes 4-15.
            reconstructed[:, 4:16] = scales6
            reconstructed[:, 16:144] = nibbles
            assert reconstructed.tobytes() == raw_disk, "Q4_K extraction byte-mismatch"
            print(f"[Q4_K] extracted {tname}: n_blocks={n_blocks}, roundtrip OK ({len(raw_disk)} bytes)")
    print("OK: gguf_inspect self-test passed")
    sys.exit(0)


def main():
    """Parse CLI arguments and dispatch to the selected subcommand."""
    parser = argparse.ArgumentParser()
    sub = parser.add_subparsers(dest="cmd", required=True)

    inspect_parser = sub.add_parser("inspect")
    inspect_parser.add_argument("--input", required=True)
    inspect_parser.add_argument("--output", help="JSONL manifest output path")
    inspect_parser.add_argument("--show", type=int, default=0, help="print first N entries")

    selftest_parser = sub.add_parser("self-test")
    selftest_parser.add_argument("--input", default="", help="optional GGUF path to test extraction on")

    args = parser.parse_args()
    # required=True above guarantees args.cmd is one of these two keys.
    handlers = {"inspect": cmd_inspect, "self-test": cmd_self_test}
    handlers[args.cmd](args)


# Script entry point: `python <this file> inspect|self-test ...`.
if __name__ == "__main__":
    main()
