import ast
import json
import math
import os
import re
import time
from datetime import datetime
from typing import Any, Dict, Iterable, List, Optional, Set
from urllib import request
from urllib.error import HTTPError

import kyc_policy_engine_llm_v2 as kyc

# Environment variables checked, in priority order, for the Gemini API key.
_GEMINI_API_KEY_ENV = ("GEMINI_API_KEY", "GOOGLE_API_KEY", "KYC_GEMINI_API_KEY")
_DEFAULT_GEMINI_MODEL = "gemini-2.5-flash"
_GEMINI_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent"
# NOTE: the tunables below are read once, at import time, from the process
# environment. .env files are only loaded later (inside the API-key/model
# getters), so values placed only in .env do not affect these settings.
_MAX_OUTPUT_TOKENS = int(os.getenv("KYC_OCR_MAX_OUTPUT_TOKENS", "4096"))
# Clamped at 0: a negative count would make the OCR retry loop run zero times
# and fail without ever sending a request.
_MAX_RETRIES = max(0, int(os.getenv("KYC_OCR_MAX_RETRIES", "2")))
# Clamped at 0.0: time.sleep() raises ValueError for negative durations.
_RETRY_BASE_SLEEP = max(0.0, float(os.getenv("KYC_OCR_RETRY_BASE_SLEEP", "0.5")))
# Case-insensitive so "True", "YES", "On", ... also enable debug logging.
_GEMINI_DEBUG = os.getenv("KYC_GEMINI_DEBUG", "0").strip().lower() in {"1", "true", "yes", "on"}


def _debug_log(event: str, detail: str, payload: Optional[Dict[str, Any]] = None) -> None:
    """Best-effort append of one diagnostic entry to the local Gemini debug log.

    Does nothing unless the module-level ``_GEMINI_DEBUG`` flag is set. An
    entry is a timestamped ``event: detail`` line, then (when *payload* is
    given) the payload as JSON -- or its ``str()`` form if it cannot be
    serialized -- truncated to 50,000 characters, then an 80-dash separator.
    Every error raised while writing is swallowed, so logging can never break
    the OCR flow.

    NOTE(review): the file lives in shared /tmp and callers pass model output
    and raw responses; confirm that is acceptable for KYC data.
    """
    if not _GEMINI_DEBUG:
        return
    log_file = "/tmp/gcv_gemini_debug.log"
    separator = "-" * 80 + "\n"
    try:
        with open(log_file, "a", encoding="utf-8") as out:
            stamp = datetime.now().isoformat()
            out.write(f"[{stamp}] {event}: {detail}\n")
            if payload is not None:
                try:
                    out.write(json.dumps(payload, ensure_ascii=False)[:50000] + "\n")
                except Exception:
                    # Payloads that cannot be JSON-encoded/written fall back to str().
                    out.write(f"{str(payload)[:50000]}\n")
            out.write(separator)
    except Exception:
        # Debug logging is strictly best-effort.
        pass


def _unquote_env_value(value: str) -> str:
    """Remove one pair of surrounding quotes from a stripped .env value.

    ``"a b"`` -> ``a b``; ``'"x"'`` -> ``"x"`` (inner quotes are kept);
    ``"v" # note`` -> ``v`` (text after the closing quote is dropped).
    Unquoted values are returned unchanged, so quote characters that are part
    of the value (e.g. ``abc"``) survive.
    """
    if value[:1] in ("'", '"'):
        closing = value.find(value[0], 1)
        # Unterminated quote: be lenient and drop only the opening quote.
        return value[1:closing] if closing != -1 else value[1:]
    return value


def _load_dotenv(search_dirs: Optional[Iterable[str]] = None) -> None:
    """Load ``KEY=value`` pairs from ``.env`` files into ``os.environ``.

    Args:
        search_dirs: directories searched, in order, for a ``.env`` file.
            Defaults to this module's directory followed by ``/var/www/html``.

    Variables that are already set are never overridden, so the real process
    environment wins over any file, and an earlier file wins over a later one.
    Blank lines, ``#`` comment lines and lines without ``=`` are skipped, an
    optional ``export`` prefix is accepted, and quoted values are unquoted via
    ``_unquote_env_value``. Files that cannot be read or decoded are ignored.
    """
    if search_dirs is None:
        search_dirs = (os.path.dirname(os.path.abspath(__file__)), "/var/www/html")
    for base in search_dirs:
        path = os.path.join(base, ".env")
        if not os.path.isfile(path):
            continue
        try:
            # utf-8-sig drops a leading BOM that would otherwise corrupt the first key.
            with open(path, "r", encoding="utf-8-sig") as f:
                for raw_line in f:
                    line = raw_line.strip()
                    if not line or line.startswith("#"):
                        continue
                    if line.startswith("export "):
                        line = line[len("export "):].strip()
                    key, sep, value = line.partition("=")
                    key = key.strip()
                    if not sep or not key:
                        continue
                    value = _unquote_env_value(value.strip())
                    if key not in os.environ:
                        os.environ[key] = value
        except (OSError, UnicodeDecodeError):
            # ignore unreadable .env files to avoid hard failure in CLI/service contexts
            continue


def _get_gemini_api_key() -> str:
    """Return the first non-blank Gemini API key found in the environment.

    ``.env`` files are loaded first (they never override variables that are
    already set), then the names in ``_GEMINI_API_KEY_ENV`` are consulted in
    order. The returned key is whitespace-stripped.

    Raises:
        RuntimeError: when none of those variables holds a non-blank value.
    """
    _load_dotenv()
    raw_values = (os.getenv(name) for name in _GEMINI_API_KEY_ENV)
    found = next((val.strip() for val in raw_values if val and val.strip()), None)
    if found is None:
        raise RuntimeError("Gemini API key is not set. Set GEMINI_API_KEY, GOOGLE_API_KEY or KYC_GEMINI_API_KEY.")
    return found


def _get_gemini_model() -> str:
    """Return the Gemini model name to call.

    ``.env`` files are loaded first (never overriding existing variables),
    then ``KYC_GEMINI_MODEL`` and ``KYC_OCR_MODEL`` are checked in that order,
    falling back to ``_DEFAULT_GEMINI_MODEL``. Values are whitespace-stripped
    and blank values count as unset -- the same treatment the API key gets --
    because the name is spliced verbatim into the request URL.
    """
    _load_dotenv()
    for env_name in ("KYC_GEMINI_MODEL", "KYC_OCR_MODEL"):
        value = (os.getenv(env_name) or "").strip()
        if value:
            return value
    return _DEFAULT_GEMINI_MODEL


def _extract_text_like_fields(node: Any, texts: List[str], seen: Optional[Set[int]] = None) -> None:
    """Append, depth-first, every string stored under a ``"text"`` key in *node*.

    Dicts and lists are walked recursively in iteration order; any other value
    is ignored. A ``"text"`` entry whose value is a container (rather than a
    string) is descended into like any other entry. *seen* records the ``id()``
    of every visited container, so shared or cyclic structures are visited once.
    Results are appended to *texts* in place; nothing is returned.
    """
    visited = set() if seen is None else seen
    if not isinstance(node, (dict, list)):
        return
    marker = id(node)
    if marker in visited:
        return
    visited.add(marker)

    if isinstance(node, list):
        for element in node:
            _extract_text_like_fields(element, texts, visited)
        return

    for name, child in node.items():
        if name == "text" and isinstance(child, str):
            texts.append(child)
        elif isinstance(child, (dict, list)):
            _extract_text_like_fields(child, texts, visited)


def _extract_json_strict(s: str) -> Dict[str, Any]:
    """Recover the JSON object payload from raw (possibly messy) model output.

    Despite the name this is deliberately lenient. Candidates are tried in
    order: the contents of each Markdown code fence, then the whole text, then
    every suffix of the text that starts at a ``{``. Each candidate is tried
    as-is (minus a leading BOM) and with trailing ``,}`` / ``,]`` commas
    removed. For each variant the parsers are ``json.loads``, then
    ``JSONDecoder.raw_decode`` from every ``{`` (tolerating surrounding
    prose), then ``ast.literal_eval`` for Python-style literals
    (``None``/``True``/single quotes).

    The first parse yielding a "payload-like" dict wins: any non-empty dict
    except a bare error object made only of ``error``/``code``/``message``/
    ``status`` keys. A one-element list wrapping such a dict is unwrapped.

    NOTE: because decoding is attempted from every ``{``, a truncated
    document can yield one of its complete inner objects instead of raising.

    Raises:
        ValueError: if the output is empty or no payload-like object is found.
    """
    s = (s or "").strip()
    if not s:
        raise ValueError("Empty model output")

    decoder = json.JSONDecoder()

    def _strip_trailing_commas(text: str) -> str:
        # tolerate ",}" and ",]" style
        return re.sub(r",\s*([}\]])", r"\1", text)

    def _is_payload_like(obj: Any) -> bool:
        if not isinstance(obj, dict) or not obj:
            return False
        if "fields" in obj:
            return True
        # A bare API-style error object ({"error": ..., "code": ...}) is not a payload.
        is_error_only = "error" in obj and set(obj).issubset({"error", "code", "message", "status"})
        return not is_error_only

    def _as_payload(obj: Any) -> Optional[Dict[str, Any]]:
        if isinstance(obj, dict):
            return obj if _is_payload_like(obj) else None
        if isinstance(obj, list) and len(obj) == 1 and isinstance(obj[0], dict):
            return _as_payload(obj[0])
        return None

    def _variants(cand: str) -> List[str]:
        # As-is (minus a leading BOM), then with trailing commas removed.
        cleaned = re.sub(r"^\uFEFF", "", cand.strip())
        return [cleaned, _strip_trailing_commas(cleaned)]

    def _try_json(variant: str) -> Optional[Dict[str, Any]]:
        try:
            payload = _as_payload(json.loads(variant))
            if payload is not None:
                return payload
        except Exception:
            pass
        # Decode from every "{" so prose before/after the object is tolerated.
        start = variant.find("{")
        while start != -1:
            try:
                obj, _ = decoder.raw_decode(variant[start:])
                payload = _as_payload(obj)
                if payload is not None:
                    return payload
            except Exception:
                pass
            start = variant.find("{", start + 1)
        return None

    def _try_literal(variant: str) -> Optional[Dict[str, Any]]:
        # Permissive fallback for JSON-like output using Python literals.
        # literal_eval evaluates literals only (no code execution).
        try:
            return _as_payload(ast.literal_eval(variant))
        except Exception:
            return None

    fence_pattern = re.compile(r"```(?:json|JSON)?\s*(.*?)\s*```", re.S)
    fenced = [
        chunk
        for chunk in ((m.group(1) or "").strip() for m in fence_pattern.finditer(s))
        if chunk
    ]

    for cand in fenced + [s]:
        for variant in _variants(cand):
            payload = _try_json(variant)
            if payload is None:
                payload = _try_literal(variant)
            if payload is not None:
                return payload

    # Suffixes of s starting at each "{". Every json.loads / raw_decode attempt
    # on these would repeat one already made on the whole text above (the
    # decoder only reads forward from a brace), so re-running them is pure
    # O(braces^2) waste; only the literal_eval fallback can still succeed here.
    start = s.find("{")
    while start != -1:
        for variant in _variants(s[start:]):
            payload = _try_literal(variant)
            if payload is not None:
                return payload
        start = s.find("{", start + 1)

    raise ValueError("Model output is not valid JSON")


def _extract_text_from_gemini_response(payload: Dict[str, Any]) -> str:
    """Concatenate the text carried by a Gemini ``generateContent`` response.

    For each candidate (in order) this collects the string ``text`` of every
    entry in ``content.parts``, then a direct string ``candidate["text"]``, then
    ``content`` itself when it is a plain string. When there are no candidates,
    or none of them yields any text, every string under a ``"text"`` key
    anywhere in *payload* is used instead. The joined result is stripped.
    """
    def _fallback() -> List[str]:
        collected: List[str] = []
        _extract_text_like_fields(payload, collected)
        return collected

    raw_candidates = payload.get("candidates") or []
    candidate_list = raw_candidates if isinstance(raw_candidates, list) else []
    if not candidate_list:
        return "".join(_fallback()).strip()

    pieces: List[str] = []
    for cand in candidate_list:
        if not isinstance(cand, dict):
            continue
        body = cand.get("content")
        if isinstance(body, dict):
            part_list = body.get("parts") or []
            if isinstance(part_list, list):
                pieces.extend(
                    part["text"]
                    for part in part_list
                    if isinstance(part, dict) and isinstance(part.get("text"), str)
                )
        direct = cand.get("text")
        if isinstance(direct, str):
            pieces.append(direct)
        if isinstance(body, str):
            pieces.append(body)

    if not pieces:
        return "".join(_fallback()).strip()
    return "".join(pieces).strip()


def _empty_field_box() -> Dict[str, Any]:
    """Return a fresh per-field result with no value and zero confidence."""
    return {"candidate_idx": 0, "text": None, "confidence": 0.0, "normalized": None}


def _clean_text(value: Any) -> Optional[str]:
    """Return ``str(value).strip()``, or None when *value* is None or blank."""
    if value is None:
        return None
    text = str(value).strip()
    return text or None


def _coerce_index(value: Any) -> int:
    """Return ``int(value)``, or 0 when it cannot be converted."""
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return 0


def _coerce_confidence(value: Any) -> float:
    """Return *value* as a float clamped to [0, 1]; None, NaN or junk become 0.0."""
    if value is None:
        return 0.0
    try:
        score = float(value)
    except (TypeError, ValueError, OverflowError):
        return 0.0
    # NaN must be checked explicitly: max(0.0, min(1.0, nan)) evaluates to 1.0,
    # which would report an unparseable confidence as full confidence.
    if math.isnan(score):
        return 0.0
    return max(0.0, min(1.0, score))


def _normalize_output_fields(fields_plan: List[Dict[str, Any]], data: Dict[str, Any]) -> Dict[str, Any]:
    """Coerce parsed model output into the stable ``{"fields": {...}}`` shape.

    *data* is either ``{"fields": {key: value, ...}}`` or a flat
    ``{key: value, ...}`` mapping; anything that is not a dict is treated as
    empty. For every entry of *fields_plan* that is a dict with a non-empty
    string ``"key"``, the output holds a dict with:

    * ``candidate_idx`` -- int (0 when missing or not convertible),
    * ``text`` -- stripped string or None,
    * ``confidence`` -- float clamped to [0, 1] (0.0 when missing/invalid/NaN),
    * ``normalized`` -- stripped string or None.

    A scalar value is wrapped as that field's ``text``; planned keys missing
    from the output get an empty box. Malformed plan entries are skipped (they
    never create ``None``/``""`` keys). Keys present in the output but not in
    the plan are passed through untouched. Existing field dicts are updated in
    place.
    """
    if not isinstance(data, dict):
        data = {}
    fields = data.get("fields")
    if not isinstance(fields, dict):
        # Flat payload: the top-level mapping itself holds the fields.
        fields = dict(data)

    for spec in fields_plan or []:
        if not isinstance(spec, dict):
            continue
        key = spec.get("key")
        if not isinstance(key, str) or not key:
            continue

        box = fields.get(key)
        if not isinstance(box, dict):
            # Missing or scalar value (e.g. {"name": "Bob"}): wrap it as text.
            scalar = box
            box = _empty_field_box()
            if scalar is not None:
                box["text"] = _clean_text(scalar)
            fields[key] = box

        box["candidate_idx"] = _coerce_index(box.get("candidate_idx", 0))
        box["text"] = _clean_text(box.get("text"))
        box["confidence"] = _coerce_confidence(box.get("confidence"))
        box["normalized"] = _clean_text(box.get("normalized"))

    return {"fields": fields}


def _finished_at_max_tokens(response: Any) -> bool:
    """Return True when any candidate in a Gemini response stopped on MAX_TOKENS."""
    if not isinstance(response, dict):
        return False
    candidates = response.get("candidates")
    if not isinstance(candidates, list):
        return False
    return any(
        isinstance(cand, dict) and cand.get("finishReason") == "MAX_TOKENS"
        for cand in candidates
    )


def llm_ocr(bundle: Dict[str, Any]) -> Dict[str, Any]:
    """Extract the planned KYC fields from *bundle* with a Gemini model.

    Request parts come from ``kyc.KYCEngine.gemini_parts_from_bundle``;
    ``bundle["fields"]`` (a list of ``{"key": ...}`` entries) is the field plan
    used to normalize the model's JSON answer.

    Retry policy (up to ``_MAX_RETRIES`` extra attempts, exponential backoff):

    * unparseable model output, or output cut off with
      ``finishReason == "MAX_TOKENS"``, is retried with a doubled output-token
      budget (capped at 8192); once retries are exhausted every planned field
      is returned empty instead of raising;
    * HTTP 429/500/502/503/504 are retried; other HTTP errors raise at once;
    * any other error is retried, then re-raised wrapped in ``RuntimeError``.

    Returns:
        ``{"fields": {...}}`` -- ``{"fields": {}}`` when the bundle yields no parts.

    Raises:
        RuntimeError: missing API key, non-retryable HTTP error, or an error
            that persists through all retries.
    """
    fields_plan = bundle.get("fields") if isinstance(bundle, dict) else []
    if not isinstance(fields_plan, list):
        fields_plan = []

    parts: List[Dict[str, Any]] = kyc.KYCEngine.gemini_parts_from_bundle(bundle or {})
    if not parts:
        return {"fields": {}}

    api_key = _get_gemini_api_key()
    model = _get_gemini_model()
    # The key is sent in the x-goog-api-key header instead of a ?key= query
    # parameter so it cannot leak through URLs captured by proxies, access
    # logs or error reports.
    url = _GEMINI_ENDPOINT.format(model=model)
    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
        "User-Agent": "kyc-gemini-ocr/1.0",
        "x-goog-api-key": api_key,
    }

    payload: Dict[str, Any] = {
        "contents": [
            {
                "role": "user",
                "parts": parts,
            }
        ],
        "generationConfig": {
            "temperature": 0.0,
            "maxOutputTokens": _MAX_OUTPUT_TOKENS,
            "responseMimeType": "application/json",
        },
    }

    last_err: Optional[Exception] = None
    max_output_tokens = _MAX_OUTPUT_TOKENS
    for attempt in range(_MAX_RETRIES + 1):
        can_retry = attempt < _MAX_RETRIES
        backoff = _RETRY_BASE_SLEEP * (2 ** attempt)
        try:
            payload["generationConfig"]["maxOutputTokens"] = max_output_tokens
            req = request.Request(
                url,
                method="POST",
                data=json.dumps(payload).encode("utf-8"),
                headers=headers,
            )
            with request.urlopen(req, timeout=30) as resp:
                raw = resp.read().decode("utf-8", errors="replace")

            data = json.loads(raw)
            _debug_log("gemini_raw_response", "response parsed", {"raw_len": len(raw), "keys": list(data.keys()) if isinstance(data, dict) else []})
            text = _extract_text_from_gemini_response(data)
            _debug_log("gemini_text_extract", "extracted text", {"text": text[:4000] if isinstance(text, str) else ""})
            if not text:
                _debug_log("gemini_no_text", "Gemini response has no text part", data)
            if can_retry and _finished_at_max_tokens(data):
                # Cut-off JSON can still "parse" (the lenient extractor falls back
                # to complete inner objects), so retry with a larger token budget
                # instead of accepting a partial answer. The final attempt still
                # parses best-effort.
                raise ValueError("Gemini output truncated (finishReason=MAX_TOKENS)")
            result_obj = _extract_json_strict(text)
            return _normalize_output_fields(fields_plan, result_obj)

        except ValueError as e:
            # Malformed or truncated model output. Keep the system alive and
            # return empty fields after retries instead of failing the flow.
            last_err = e
            _debug_log("gemini_parse_error", "JSON parse error", {"error": str(e), "attempt": attempt})
            if can_retry:
                if max_output_tokens < 8192:
                    max_output_tokens = min(8192, max_output_tokens * 2)
                time.sleep(backoff)
                continue
            return _normalize_output_fields(fields_plan, {"fields": {}})

        except HTTPError as e:
            body = ""
            try:
                body = e.read().decode("utf-8", errors="replace")
            except Exception:
                pass
            last_err = RuntimeError(f"Gemini HTTP {e.code}: {e.reason}; body={body}")
            if can_retry and e.code in {429, 500, 502, 503, 504}:
                time.sleep(backoff)
                continue
            raise last_err from e

        except Exception as e:
            last_err = e
            if can_retry:
                time.sleep(backoff)
                continue
            raise RuntimeError(f"Gemini OCR failed: {e}") from e

    # Only reachable when the retry loop runs zero times (negative _MAX_RETRIES).
    raise RuntimeError(f"Gemini OCR failed after retries: {last_err}") from last_err
