"""
NodeMind Multimodal Benchmark — Interactive Query Demo
=======================================================
Lets you type any text query and see side-by-side results from:
  • NodeMind NM-256  (128× compression, integer binary Hamming search)
  • NodeMind NM-512  ( 64× compression, integer binary Hamming search)
  • NodeMind NM-1024 ( 32× compression, integer binary Hamming search)
  • Gemini RAG       (float32 HNSW cosine search — the baseline)

All four indexes search the same 200-item multimodal corpus
(text, images, audio, tables, code).

Requirements
------------
    pip install flagembedding hnswlib numpy

You also need:
    • The BGE-Visualized-M3 model weights: bge-visualized-m3.pth
      Download from: https://huggingface.co/BAAI/BGE-Visualized (M3 variant)
      Place it in the same folder as this script, or set BGE_WEIGHTS env var.

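    • (Optional) A Gemini API key and the google-genai package to run the
      RAG baseline (pip install google-genai).
      Set the environment variable GEMINI_API_KEY=your_key_here, or pass
      --gemini-key on the command line.
      Without a key, only NodeMind results are shown.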

Usage
-----
    python query_demo.py
    python query_demo.py --query "dog barking sound" --top_k 5
    python query_demo.py --no-rag   # skip Gemini RAG (no API key needed)

The NodeMind search method is protected under AU 2026901656 / AU 2026901657.
"""

import argparse
import importlib.util
import json
import os
import pickle
import sys
import time
from pathlib import Path

import numpy as np

# ─────────────────────────────────────────────────────────────────────────────
# PATHS
# ─────────────────────────────────────────────────────────────────────────────

# Directory containing this script; every other path is anchored here.
HERE = Path(__file__).parent.resolve()
CORPUS_DIR = HERE / "corpus"
INDEX_DIR = HERE / "index"

# Index artefacts read by the demo.
NM_INDEX_PATH = INDEX_DIR.joinpath("nodemind_index.pkl")
RAG_INDEX_PATH = INDEX_DIR.joinpath("gemini_rag_index.pkl")
HNSW_PATH = INDEX_DIR.joinpath("gemini_hnsw.bin")

# BGE weights: $BGE_WEIGHTS when set, otherwise a file next to this script.
_DEFAULT_BGE_WEIGHTS = HERE / "bge-visualized-m3.pth"
BGE_WEIGHTS = os.environ.get("BGE_WEIGHTS", str(_DEFAULT_BGE_WEIGHTS))

# Gemini embedding model and output size for the RAG baseline
# (presumably must match how gemini_rag_index.pkl was built — confirm).
GEMINI_MODEL = "gemini-embedding-2"
GEMINI_DIM = 1024

# ─────────────────────────────────────────────────────────────────────────────
# NODEMIND SEARCH — loaded from compiled module
# The search and binarisation method is protected under AU 2026901656 /
# AU 2026901657. The index is self-contained and works without reading
# the patent.
# ─────────────────────────────────────────────────────────────────────────────

def _load_nm_search():
    """Load ``nm_search`` from the compiled ``_nm_core.pyc`` next to this script.

    Returns:
        The ``nm_search`` callable defined by that module.

    Exits the process with status 1 and a readable message when the file is
    missing or cannot be loaded. A ``.pyc`` file only loads on the Python
    version that compiled it; otherwise the import fails with "bad magic
    number".
    """
    pyc = HERE / "_nm_core.pyc"
    if not pyc.exists():
        print("[ERROR] Missing _nm_core.pyc — please re-download the full package.")
        sys.exit(1)

    spec = importlib.util.spec_from_file_location("_nm_core", pyc)
    if spec is None or spec.loader is None:
        print(f"[ERROR] Cannot create an import spec for {pyc}.")
        sys.exit(1)

    mod = importlib.util.module_from_spec(spec)
    try:
        spec.loader.exec_module(mod)
    except (ImportError, EOFError, ValueError) as exc:
        # ImportError covers "bad magic number" (wrong Python version);
        # EOFError / ValueError come from truncated or corrupt bytecode.
        print(f"[ERROR] Could not load {pyc.name}: {exc}")
        print(
            f"  You are running Python {sys.version_info.major}.{sys.version_info.minor}; "
            "a .pyc only loads on the Python version that compiled it."
        )
        sys.exit(1)
    return mod.nm_search


nm_search = _load_nm_search()

# ─────────────────────────────────────────────────────────────────────────────
# HELPERS — file path resolution
# ─────────────────────────────────────────────────────────────────────────────

def resolve_path(stored_path: str, corpus_dir=None) -> str:
    """Map a file path stored in the index onto this machine's corpus folder.

    The part of the path from the first ``images`` or ``audio`` component
    onward is kept and re-rooted under *corpus_dir*. Any other path falls
    back to ``corpus_dir / <file name>``.

    Args:
        stored_path: Path string as recorded in the index. It may use POSIX
            or Windows separators.
        corpus_dir: Root to re-anchor under. Defaults to ``CORPUS_DIR``.

    Returns:
        The resolved path as a string.
    """
    base = CORPUS_DIR if corpus_dir is None else Path(corpus_dir)
    # On POSIX, "\\" is an ordinary filename character, so a path recorded on
    # Windows would not split into components. Normalise the separators first.
    p = Path(str(stored_path or "").replace("\\", "/"))
    parts = p.parts
    for anchor in ("images", "audio"):
        if anchor in parts:
            return str(base.joinpath(*parts[parts.index(anchor):]))
    return str(base / p.name)


def _snippet(value, width: int = 80) -> str:
    """Return *value* as a single-line preview of at most *width* characters.

    ``None`` becomes ``""``. Line breaks become spaces so a multi-line
    document cannot break the one-result-per-line listing. An ellipsis is
    appended only when the text was actually truncated.
    """
    text = "" if value is None else str(value)
    preview = text[:width].replace("\r", " ").replace("\n", " ")
    return f"{preview}..." if len(text) > width else preview


def item_display(item: dict) -> str:
    """Format one corpus item as a single display line for the results list.

    The prefix shows the modality and ``item_id``. Text, table and code items
    show a content preview. Image and audio items show the resolved file name
    plus their caption or label. Unknown modalities show a truncated
    ``str(item)``.
    """
    mod = item.get("modality", "?")
    iid = item.get("item_id", "?")
    content = item.get("content") or ""

    if mod == "text":
        return f"[TEXT  {iid}] {_snippet(content)}"
    if mod == "image":
        cap = item.get("caption", item.get("label", ""))
        fname = Path(resolve_path(str(content))).name
        return f"[IMAGE {iid}] {fname} — {cap}"
    if mod == "audio":
        lt = item.get("label_text", item.get("label", ""))
        fname = Path(resolve_path(str(content))).name
        return f"[AUDIO {iid}] {fname} — {lt}"
    if mod == "table":
        return f"[TABLE {iid}] {_snippet(content)}"
    if mod == "code":
        return f"[CODE  {iid}] {_snippet(content)}"
    # str() guards against a present-but-None modality.
    return f"[{str(mod).upper():<5} {iid}] {str(item)[:80]}"


# ─────────────────────────────────────────────────────────────────────────────
# EMBEDDING — BGE-Visualized-M3
# ─────────────────────────────────────────────────────────────────────────────

_bge_model = None  # process-wide cache filled by load_bge()


def load_bge():
    """Return the BGE-Visualized-M3 model, loading it on the first call.

    Import locations are tried in order. The first class that imports is
    constructed with base model ``BAAI/bge-m3`` and the weights at
    ``BGE_WEIGHTS``, then switched to eval mode.

    Exits the process with status 1 if the weights file is missing or none
    of the candidate modules can be imported. Errors raised while
    constructing the model propagate instead of being masked.
    """
    global _bge_model
    if _bge_model is not None:
        return _bge_model

    if not Path(BGE_WEIGHTS).exists():
        print(f"\n[ERROR] BGE weights not found at: {BGE_WEIGHTS}")
        print("  Download from: https://huggingface.co/BAAI/BGE-Visualized (M3 variant)")
        print("  Place bge-visualized-m3.pth in the same folder as this script,")
        print("  or set: export BGE_WEIGHTS=/path/to/bge-visualized-m3.pth")
        sys.exit(1)

    print("  Loading BGE-Visualized-M3 model (this takes ~10s)...")
    # (module, class, keyword that names the base text model).
    # Visualized_BGE takes the base model as `model_name_bge`.
    candidates = (
        ("FlagEmbedding.visual", "FlagVisualRetrieval", "model_name_or_path"),
        ("visual_bge.modeling", "Visualized_BGE", "model_name_bge"),
        ("FlagEmbedding.visual.modeling", "Visualized_BGE", "model_name_bge"),
    )
    model = None
    for module_name, class_name, base_model_kwarg in candidates:
        try:
            model_cls = getattr(importlib.import_module(module_name), class_name)
        except (ImportError, AttributeError):
            continue
        model = model_cls(**{base_model_kwarg: "BAAI/bge-m3", "model_weight": BGE_WEIGHTS})
        break

    if model is None:
        print("\n[ERROR] BGE-Visualized not installed.")
        print("  pip install flagembedding")
        print("  or: pip install -e ./FlagEmbedding/research/visual_bge/")
        sys.exit(1)

    # A torch nn.Module starts in training mode. eval() disables dropout so
    # the same text always maps to the same embedding.
    if hasattr(model, "eval"):
        model.eval()
    _bge_model = model
    print("  BGE model loaded.\n")
    return _bge_model


def embed_text_bge(text: str) -> np.ndarray:
    """Embed *text* with BGE-Visualized-M3.

    Returns:
        A 1-D float32 vector scaled to unit L2 norm. An all-zero result is
        returned unscaled.
    """
    model = load_bge()
    vec = model.encode(text=text)
    # encode() is not called under torch.no_grad(), so a returned tensor may
    # require grad, and .numpy() raises on such tensors. detach() drops the
    # autograd link and is harmless when there is none.
    if hasattr(vec, "detach"):
        vec = vec.detach()
    if hasattr(vec, "cpu"):
        vec = vec.cpu().numpy()
    vec = np.array(vec, dtype=np.float32).flatten()
    norm = np.linalg.norm(vec)
    return vec / norm if norm > 0 else vec


# ─────────────────────────────────────────────────────────────────────────────
# EMBEDDING — Gemini (RAG baseline)
# ─────────────────────────────────────────────────────────────────────────────

_gemini_client = None  # cached google-genai client, created by load_gemini()


def load_gemini(api_key: str):
    """Return the cached google-genai client, creating it on the first call.

    *api_key* is only used when the client is first created. Later calls
    return the cached client unchanged. Returns ``None`` after printing a
    warning if ``google-genai`` cannot be imported.
    """
    global _gemini_client
    if _gemini_client is None:
        try:
            from google import genai as gai
            _gemini_client = gai.Client(api_key=api_key)
        except ImportError:
            print("[WARN] google-genai not installed. pip install google-genai")
            return None
        print("  Gemini client loaded.\n")
    return _gemini_client


def embed_text_gemini(text: str, client) -> np.ndarray:
    """Embed *text* with the Gemini embedding model, retrying up to 4 times.

    Args:
        text: Query text to embed.
        client: A google-genai client, as returned by ``load_gemini``.

    Returns:
        A float32 vector of length ``GEMINI_DIM`` scaled to unit L2 norm.
        If every attempt fails, a warning is printed and an all-zero vector
        is returned.
    """
    from google.genai import types as gtypes

    config = gtypes.EmbedContentConfig(output_dimensionality=GEMINI_DIM)
    attempts = 4
    for attempt in range(attempts):
        try:
            result = client.models.embed_content(
                model=GEMINI_MODEL,
                contents=text,
                config=config,
            )
            vec = np.array(result.embeddings[0].values, dtype=np.float32)
            norm = np.linalg.norm(vec)
            return vec / norm if norm > 0 else vec
        except Exception as e:
            err = str(e).lower()
            # Match rate-limit markers only; a bare "rate" would also match
            # words such as "generate".
            rate_limited = any(
                marker in err
                for marker in ("429", "resource_exhausted", "quota", "rate limit", "rate_limit")
            )
            if attempt == attempts - 1:
                # Final attempt: report and stop; sleeping now would only delay the failure.
                print(f"  [Gemini error] {e}")
                break
            if rate_limited:
                wait = min(2 ** attempt, 30)
                print(f"  [Rate limit] Waiting {wait}s...")
            else:
                print(f"  [Gemini error] {e}")
                wait = 2
            time.sleep(wait)

    print(f"  [WARN] Gemini embedding failed after {attempts} attempts; using a zero vector.")
    return np.zeros(GEMINI_DIM, dtype=np.float32)


# ─────────────────────────────────────────────────────────────────────────────
# GEMINI RAG COSINE SEARCH (HNSW)
# ─────────────────────────────────────────────────────────────────────────────

def rag_search(rag_data: dict, q_float: np.ndarray, top_k: int = 10):
    """Return the *top_k* nearest corpus items to *q_float* from the Gemini HNSW index.

    Args:
        rag_data: Loaded RAG index metadata. Must contain ``"item_ids"``
            (hnswlib label -> item id) and ``"dim"``.
        q_float: Query embedding.
        top_k: Number of neighbours requested. It is clamped to the number
            of indexed elements, because hnswlib raises when asked for more
            results than it can return.

    Returns:
        A list of ``(item_id, cosine_distance)`` pairs as returned by
        hnswlib. Returns ``[]`` when *top_k* < 1, the index is empty, the
        query vector is all zeros (e.g. the Gemini embedding failed, which
        would otherwise produce NaN cosine distances), or hnswlib is not
        installed.
    """
    ids = rag_data["item_ids"]
    k = min(int(top_k), len(ids))
    if k < 1:
        return []

    q = np.asarray(q_float, dtype=np.float32).reshape(1, -1)
    if not np.any(q):
        print("  [WARN] Empty query embedding — skipping RAG search.")
        return []

    try:
        import hnswlib
    except ImportError:
        print("[ERROR] hnswlib not installed. pip install hnswlib")
        return []

    hnsw = hnswlib.Index(space="cosine", dim=rag_data["dim"])
    hnsw.load_index(str(HNSW_PATH), max_elements=len(ids))
    k = min(k, hnsw.get_current_count())
    if k < 1:
        return []
    labels, distances = hnsw.knn_query(q, k=k)
    return [(ids[lbl], float(dist)) for lbl, dist in zip(labels[0], distances[0])]


# ─────────────────────────────────────────────────────────────────────────────
# MAIN DEMO
# ─────────────────────────────────────────────────────────────────────────────

_INDEX_CACHE: dict = {}  # str(path) -> unpickled index object


def _load_index_cached(path, message: str):
    """Unpickle *path* on first use and return the cached object afterwards.

    *message* is printed only when the file is actually read. This keeps the
    interactive loop from re-reading both index files for every query.

    SECURITY: ``pickle.load`` can execute arbitrary code. Only load index
    files obtained from a source you trust.
    """
    key = str(path)
    if key not in _INDEX_CACHE:
        print(message)
        with open(path, "rb") as f:
            _INDEX_CACHE[key] = pickle.load(f)
    return _INDEX_CACHE[key]


def run_query(query_text: str, top_k: int, use_rag: bool, gemini_key: str):
    """Embed *query_text* and print the top-*top_k* results from each index.

    Always searches the three NodeMind variants (nm_256 / nm_512 / nm_1024)
    with a BGE query embedding. When *use_rag* is true and a Gemini client
    can be created with *gemini_key*, it also searches the Gemini HNSW
    baseline with a Gemini query embedding. Embedding and search times are
    printed in milliseconds.
    """
    nm_data = _load_index_cached(NM_INDEX_PATH, "\n  Loading NodeMind index...")
    id_to_item = {item["item_id"]: item for item in nm_data["corpus"]}

    rag_data = None
    gclient = None
    if use_rag:
        rag_data = _load_index_cached(RAG_INDEX_PATH, "  Loading Gemini RAG index...")
        gclient = load_gemini(gemini_key)
        if gclient is None:
            print("  [WARN] Gemini unavailable. Running NodeMind only.")
            use_rag = False

    print("\n  Embedding query with BGE-M3...")
    t0 = time.perf_counter()
    q_bge = embed_text_bge(query_text)
    print(f"  BGE embed: {(time.perf_counter()-t0)*1000:.1f} ms")

    q_gemini = None
    if use_rag:
        print("  Embedding query with Gemini...")
        t0 = time.perf_counter()
        q_gemini = embed_text_gemini(query_text, gclient)
        print(f"  Gemini embed: {(time.perf_counter()-t0)*1000:.1f} ms")

    print()
    print("=" * 78)
    print(f"  QUERY: \"{query_text}\"")
    print("=" * 78)

    variant_labels = {
        "nm_256":  "NodeMind NM-256  (128× compression)",
        "nm_512":  "NodeMind NM-512  ( 64× compression)",
        "nm_1024": "NodeMind NM-1024 ( 32× compression)",
    }

    for vname, label in variant_labels.items():
        print(f"\n  ── {label} ──")
        t0 = time.perf_counter()
        results = nm_search(nm_data, q_bge, vname, top_k=top_k)
        print(f"  Search time: {(time.perf_counter()-t0)*1000:.2f} ms")
        for rank, (item_id, dist) in enumerate(results, 1):
            item = id_to_item.get(item_id, {"item_id": item_id})
            print(f"  #{rank:<2} [hamming={dist:>4}] {item_display(item)}")

    if use_rag:
        print("\n  ── Gemini RAG float32 HNSW (1× — baseline) ──")
        t0 = time.perf_counter()
        results = rag_search(rag_data, q_gemini, top_k=top_k)
        print(f"  Search time: {(time.perf_counter()-t0)*1000:.2f} ms")
        for rank, (item_id, dist) in enumerate(results, 1):
            item = id_to_item.get(item_id, {"item_id": item_id})
            print(f"  #{rank:<2} [dist={dist:.4f}] {item_display(item)}")

    print()


def interactive_loop(top_k: int, use_rag: bool, gemini_key: str):
    """Read queries from stdin and run each one with ``run_query``.

    Stops on 'quit' / 'exit' / 'q', Ctrl+C, or end of input. If answering a
    query raises an ordinary exception, the error is reported and the loop
    keeps going, so one failing query does not end the session. SystemExit
    is not caught, so ``sys.exit()`` calls still terminate the program.
    """
    print("\n" + "=" * 78)
    print("  NodeMind Multimodal Benchmark — Interactive Query Demo")
    print("=" * 78)
    print("  Corpus: 200 items (text, images, audio, tables, code)")
    print("  Indexes: NM-256 / NM-512 / NM-1024" + (" + Gemini RAG" if use_rag else ""))
    print(f"  Top-K: {top_k}")
    print("\n  Type a query and press Enter.  Ctrl+C or 'quit' to exit.\n")

    while True:
        try:
            query = input("  Query> ").strip()
        except (KeyboardInterrupt, EOFError):
            print("\n  Bye.")
            break
        if not query:
            continue
        if query.lower() in ("quit", "exit", "q"):
            print("  Bye.")
            break
        try:
            run_query(query, top_k, use_rag, gemini_key)
        except KeyboardInterrupt:
            # The banner promises that Ctrl+C exits, even mid-query.
            print("\n  Bye.")
            break
        except Exception as exc:  # REPL boundary: report and keep the session alive
            print(f"  [ERROR] Query failed: {type(exc).__name__}: {exc}\n")


def main():
    """Command-line entry point.

    Parses the arguments and decides whether the Gemini RAG baseline runs:
    it needs an API key and is skipped with --no-rag. Then checks that the
    required index files exist and runs either a single --query or the
    interactive loop.

    Exits with status 1 when required files are missing. argparse exits with
    status 2 on invalid arguments, such as a non-positive --top_k.
    """
    parser = argparse.ArgumentParser(
        description="NodeMind Multimodal Benchmark — Query Demo",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument("--query", "-q", type=str, default=None,
                        help="Run a single query (non-interactive).")
    parser.add_argument("--top_k", "-k", type=int, default=5,
                        help="Number of results to return (default: 5)")
    parser.add_argument("--no-rag", action="store_true",
                        help="Skip Gemini RAG baseline (no API key needed)")
    parser.add_argument("--gemini-key", type=str, default=None,
                        help="Gemini API key (defaults to GEMINI_API_KEY env var)")
    args = parser.parse_args()

    # Zero or negative k makes no sense for any of the searches.
    if args.top_k < 1:
        parser.error("--top_k must be a positive integer")

    gemini_key = args.gemini_key or os.environ.get("GEMINI_API_KEY", "")
    use_rag = (not args.no_rag) and bool(gemini_key)

    if not use_rag and not args.no_rag:
        print("[INFO] No GEMINI_API_KEY found — running NodeMind only.")
        print("       Set GEMINI_API_KEY or use --no-rag to suppress this message.\n")

    required = [NM_INDEX_PATH, CORPUS_DIR / "corpus.json", HERE / "_nm_core.pyc"]
    if use_rag:
        required += [RAG_INDEX_PATH, HNSW_PATH]
    missing = [str(p) for p in required if not p.exists()]
    if missing:
        print("[ERROR] Missing required files:")
        for m in missing:
            print(f"  • {m}")
        sys.exit(1)

    if args.query:
        run_query(args.query, args.top_k, use_rag, gemini_key)
    else:
        interactive_loop(args.top_k, use_rag, gemini_key)


if __name__ == "__main__":
    main()
