# llm_ocr_openai.py

import json
import math
import os
import time
from typing import Any, Dict, Iterable, Optional

import kyc_policy_engine_llm_v2 as kyc

try:
    from openai import OpenAI
    from openai import RateLimitError, APITimeoutError, APIConnectionError, InternalServerError
except Exception:  # pragma: no cover
    OpenAI = None  # type: ignore
    RateLimitError = APITimeoutError = APIConnectionError = InternalServerError = Exception  # type: ignore


# Primary OCR model; override with KYC_OCR_MODEL (blank values fall back to the default).
_DEFAULT_MODEL = os.getenv("KYC_OCR_MODEL", "gpt-5.2").strip() or "gpt-5.2"
# Models tried in order by llm_ocr. The configured model goes first so that
# KYC_OCR_MODEL actually takes effect; duplicate entries are dropped by
# _normalize_model_candidates.
_MODEL_FALLBACKS = [
    _DEFAULT_MODEL,
    "gpt-5.2",
]
# Lower-case substrings that, when found in the lower-cased error message, mark
# an error as a model/parameter incompatibility rather than a transient failure.
_KNOWN_UNSUPPORTED_SUBSTR = (
    "does not exist",
    "is not valid for this api",
    "model is not supported",
    "model is not available",
    "unknown model",
    "unsupported reasoning",
    "invalid reasoning",
    "temperature",
    "does not support temperature",
)

# Output token cap sent with every OCR request.
_MAX_OUTPUT_TOKENS = int(os.getenv("KYC_OCR_MAX_OUTPUT_TOKENS", "2200"))
# Extra attempts per (model, reasoning effort) pair after the first attempt.
_MAX_RETRIES = int(os.getenv("KYC_OCR_MAX_RETRIES", "2"))
# Base delay for exponential backoff between retries (base * 2**attempt).
_RETRY_BASE_SLEEP = float(os.getenv("KYC_OCR_RETRY_BASE_SLEEP", "0.4"))
# Request timeout passed to the OpenAI client constructor.
_OPENAI_TIMEOUT = float(os.getenv("KYC_OCR_OPENAI_TIMEOUT", "90"))
# Preferred reasoning effort for GPT-5 family models; a value outside
# _GPT5_ALLOWED_REASONING_EFFORT means no reasoning parameter is sent.
_GPT5_REASONING_EFFORT = os.getenv("KYC_OCR_GPT5_REASONING_EFFORT", "xhigh").strip().lower()
# Debug printing toggle (KYC_OCR_DEBUG=1/true/yes/on).
_DEBUG = os.getenv("KYC_OCR_DEBUG", "0").strip().lower() in ("1", "true", "yes", "on")

_GPT5_ALLOWED_REASONING_EFFORT = {
    "none",
    "minimal",
    "low",
    "medium",
    "high",
    "xhigh",
}


# OpenAI client, created lazily by _get_client() and reused by later calls.
_OAI_CLIENT: Optional["OpenAI"] = None


def _normalize_model_candidates() -> list[str]:
    """Return the models to try, in priority order, without duplicates.

    Entries of ``_MODEL_FALLBACKS`` that are empty or not strings are ignored.
    If nothing usable remains, a built-in list of GPT-5 / GPT-4o model names is
    returned instead.
    """
    ordered: list[str] = []
    for name in _MODEL_FALLBACKS:
        usable = bool(name) and isinstance(name, str)
        if usable and name not in ordered:
            ordered.append(name)
    return ordered or ["gpt-5.2", "gpt-5-mini", "gpt-5-nano", "gpt-5", "gpt-4o-mini"]


def _is_gpt5_family(model: str) -> bool:
    """Return True when *model* is a non-empty name starting with ``gpt-5``."""
    if not model:
        return False
    return model.startswith("gpt-5")


def _build_request_kwargs(
    model: str,
    content_items: Any,
    reasoning_effort: Optional[str] = None,
) -> Dict[str, Any]:
    """Assemble keyword arguments for ``client.responses.create``.

    *content_items* is sent as the content of a single user message. The
    request asks for a JSON-object text format, caps output at
    ``_MAX_OUTPUT_TOKENS`` and sets ``store=False``. A ``reasoning`` block is
    added only for GPT-5 family models and only when *reasoning_effort* is
    truthy.
    """
    user_message = {"role": "user", "content": content_items}
    request: Dict[str, Any] = dict(
        model=model,
        input=[user_message],
        text={"format": {"type": "json_object"}},
        max_output_tokens=_MAX_OUTPUT_TOKENS,
        store=False,
    )
    # Family check runs first, mirroring the original short-circuit order.
    if _is_gpt5_family(model) and reasoning_effort:
        request["reasoning"] = {"effort": reasoning_effort}
    return request


def _reasoning_chain() -> list[Optional[str]]:
    """Return the ordered reasoning-effort values to try for one model.

    With a recognised ``_GPT5_REASONING_EFFORT`` the chain is
    ``[configured, "high", None]``; "high" appears only once when it is
    already the configured value. ``None`` means "send no reasoning parameter".
    An unrecognised configured value yields just ``[None]``.
    """
    configured = _GPT5_REASONING_EFFORT
    if configured not in _GPT5_ALLOWED_REASONING_EFFORT:
        return [None]

    steps: list[Optional[str]] = [configured]
    if configured != "high":
        steps.append("high")
    # Last resort: let the server pick its default effort.
    steps.append(None)
    return steps


def _is_model_access_error(err: Exception, fragments: Optional[Iterable[str]] = None) -> bool:
    """Return True if *err*'s message looks like a model/parameter incompatibility.

    Args:
        err: The exception raised by the API call.
        fragments: Substrings to look for; defaults to ``_KNOWN_UNSUPPORTED_SUBSTR``.

    Returns:
        True when any fragment occurs in the error message, compared
        case-insensitively.
    """
    if fragments is None:
        fragments = _KNOWN_UNSUPPORTED_SUBSTR
    msg = str(err).lower()
    # Lower-case both sides: a mixed-case fragment (e.g. "... this API") would
    # otherwise never match the lower-cased message.
    return any(fragment.lower() in msg for fragment in fragments)


def _debug_log(*parts: Any) -> None:
    """Print *parts* after a ``[KYC-OCR-DEBUG]`` prefix when ``_DEBUG`` is set."""
    if not _DEBUG:
        return
    print("[KYC-OCR-DEBUG]", *parts)


def _collect_text(node: Any, out: list[str]) -> None:
    """Recursively collect non-empty text from dict/list/tuple/str payloads.

    Only plain ``str``, ``dict``, ``list`` and ``tuple`` nodes are inspected;
    SDK model objects must be converted first (e.g. via ``model_dump()``).
    Every collected string is stripped, and empty results are skipped, so
    callers can join *out* without blank entries.

    Args:
        node: Payload fragment to walk.
        out: Accumulator; collected strings are appended in walk order.
    """
    if isinstance(node, str):
        stripped = node.strip()
        if stripped:
            out.append(stripped)
        return

    if isinstance(node, dict):
        # Fast-path common text-bearing keys. String values go through the str
        # branch above so they get the same strip/skip-empty treatment.
        output_text = node.get("output_text")
        if isinstance(output_text, str):
            _collect_text(output_text, out)
        text = node.get("text")
        if isinstance(text, str):
            _collect_text(text, out)
        elif isinstance(text, dict):
            for inner_key in ("value", "text"):
                inner = text.get(inner_key)
                if isinstance(inner, str):
                    _collect_text(inner, out)
        for key in ("content", "output", "choices", "messages", "response", "result"):
            if key in node:
                _collect_text(node[key], out)
        return

    if isinstance(node, (list, tuple)):
        for item in node:
            _collect_text(item, out)


def _get_client() -> "OpenAI":
    """Return the shared OpenAI client, constructing it on first use.

    Raises:
        RuntimeError: if the ``openai`` package could not be imported, or if
            ``OPENAI_API_KEY`` is unset or empty. These checks run here rather
            than at import time, so the module can load without credentials.
    """
    global _OAI_CLIENT
    if _OAI_CLIENT is None:
        if OpenAI is None:
            raise RuntimeError("openai package not available")
        key = os.getenv("OPENAI_API_KEY")
        if not key:
            raise RuntimeError("OPENAI_API_KEY is not set")
        _OAI_CLIENT = OpenAI(api_key=key, timeout=_OPENAI_TIMEOUT)
    return _OAI_CLIENT


def _output_text_from_items(output: Any) -> str:
    """Concatenate ``output_text`` parts of a dumped Responses ``output`` list."""
    if not isinstance(output, list):
        return ""
    chunks: list[str] = []
    for item in output:
        content = item.get("content") if isinstance(item, dict) else None
        if not isinstance(content, list):
            continue
        for part in content:
            if not isinstance(part, dict) or part.get("type") != "output_text":
                continue
            txt = part.get("text")
            if isinstance(txt, dict):
                txt = txt.get("value") or txt.get("text")
            if isinstance(txt, str):
                chunks.append(txt)
    return "".join(chunks).strip()


def _response_to_text(resp: Any) -> str:
    """Extract the model's output text from a Responses API result.

    Lookup order:
      1. ``resp.output_text`` when it is a ``str``.
      2. The ``output_text`` key of the dumped payload (``model_dump()``, or
         ``__dict__`` if dumping is unavailable or fails).
      3. ``output[*].content[*]`` parts whose ``type`` is ``"output_text"``,
         concatenated without a separator.
      4. A generic walk of the payload via ``_collect_text`` (newline-joined),
         for schema variants the typed walk does not recognise.

    Returns:
        The stripped text, or ``""`` when *resp* is None or no text was found.
    """
    if resp is None:
        return ""

    direct = getattr(resp, "output_text", None)
    if isinstance(direct, str):
        return direct.strip()

    payload = None
    if hasattr(resp, "model_dump"):
        try:
            payload = resp.model_dump()
        except Exception:
            # Best effort only: fall back to the raw attribute dict below.
            payload = None
    if payload is None:
        payload = getattr(resp, "__dict__", None)
    if not isinstance(payload, dict):
        return ""

    text = payload.get("output_text")
    if isinstance(text, str):
        return text.strip()

    # Typed walk first. The generic walk below collects every "text" string
    # under "content"/"output" regardless of part type, so non-output_text
    # parts (such as reasoning text, if present) would be mixed into the JSON.
    typed = _output_text_from_items(payload.get("output"))
    if typed:
        return typed

    chunks: list[str] = []
    _collect_text(payload, chunks)
    return "\n".join(chunks).strip()


def _content_items_to_chat_message(content_items: Any) -> list:
    """Convert Responses-API content items into Chat Completions user content.

    * ``input_text`` / ``text`` items with a string ``text`` become
      ``{"type": "text", "text": ...}`` parts.
    * ``input_image`` items with a non-blank string ``image_url`` become
      ``{"type": "image_url", "image_url": {"url": ...}}`` parts. A non-empty
      string ``detail`` is carried over, so the fallback request keeps the
      primary request's image setting.

    Any other item is skipped, and a non-list input yields ``[]``.
    """
    user_content: list = []
    if not isinstance(content_items, list):
        return user_content

    for item in content_items:
        if not isinstance(item, dict):
            continue
        item_type = str(item.get("type") or "")
        if item_type in ("input_text", "text"):
            txt = item.get("text", "")
            if isinstance(txt, str):
                user_content.append({"type": "text", "text": txt})
        elif item_type == "input_image":
            image_url = item.get("image_url")
            if not (isinstance(image_url, str) and image_url.strip()):
                continue
            image_part: Dict[str, Any] = {"url": image_url}
            detail = item.get("detail")
            if isinstance(detail, str) and detail:
                image_part["detail"] = detail
            user_content.append({"type": "image_url", "image_url": image_part})

    return user_content


def _call_chat_fallback(client: "OpenAI", model: str, content_items: Any) -> str:
    """Best-effort retry of the OCR prompt through Chat Completions.

    Used when the Responses API returned no text.

    Returns:
        The stripped string content of the first choice, or ``""`` when there
        is nothing to send, the reply has no string content, or the call
        fails. Failures are only logged via ``_debug_log``, never raised, so
        the caller's empty-output retry path handles them.
    """
    user_content = _content_items_to_chat_message(content_items)
    if not user_content:
        return ""

    try:
        chat_resp = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": user_content}],
            response_format={"type": "json_object"},
            # GPT-5 / reasoning models reject ``max_tokens`` on Chat Completions
            # and require ``max_completion_tokens``. Because errors are
            # swallowed below, ``max_tokens`` made this fallback silently
            # return "" for those models.
            max_completion_tokens=_MAX_OUTPUT_TOKENS,
        )
        choices = getattr(chat_resp, "choices", None)
        if choices:
            message = getattr(choices[0], "message", None)
            content = getattr(message, "content", None)
            if isinstance(content, str):
                return content.strip()
    except Exception as e:
        # Deliberately best effort: an empty result feeds the normal retry/error path.
        _debug_log("chat fallback failed:", repr(e))
        return ""
    return ""


def _extract_json_strict(s: str) -> Dict[str, Any]:
    """Parse the model's reply as a JSON object.

    The whole stripped string is parsed first. If that fails, or yields
    something other than an object (array, string, number, null), the span
    from the first ``{`` to the last ``}`` is parsed instead. This tolerates
    prose or Markdown fences around the JSON.

    Raises:
        ValueError: ``"Empty model output"`` for blank input (callers match on
            this prefix); ``json.JSONDecodeError`` (a ValueError subclass) if
            the brace span is malformed; otherwise a ValueError when no JSON
            object can be found.
    """
    s = (s or "").strip()
    if not s:
        raise ValueError("Empty model output")

    try:
        parsed = json.loads(s)
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        parsed = None
    if isinstance(parsed, dict):
        return parsed

    start = s.find("{")
    end = s.rfind("}")
    if start >= 0 and end > start:
        # A span that starts with "{" can only parse to an object.
        return json.loads(s[start:end + 1])

    raise ValueError("Model output is not a valid JSON object")


def _coerce_confidence(raw: Any) -> Optional[float]:
    """Return *raw* as a float clamped to [0.0, 1.0]; None if absent, invalid, or non-finite."""
    if raw is None:
        return None
    try:
        value = float(raw)
    except (TypeError, ValueError, OverflowError):
        return None
    # NaN must be rejected explicitly: min(1.0, nan) evaluates to 1.0, so
    # clamping alone would report maximal confidence for a garbage value.
    if not math.isfinite(value):
        return None
    return max(0.0, min(1.0, value))


def _normalize_fields(fields: Dict[str, Any]) -> Dict[str, Any]:
    """Coerce every OCR field entry in *fields* to the canonical shape, in place.

    Non-dict entries are replaced with a placeholder
    (``candidate_idx=0, text="", confidence=0.0, normalized=None``). Dict
    entries keep any extra keys, and:

    * ``candidate_idx`` becomes an int (``0`` if missing or not convertible);
    * ``text`` becomes a str (``""`` for None);
    * ``confidence`` becomes a float in [0, 1], or None (see ``_coerce_confidence``);
    * ``normalized`` becomes a str, or stays None.

    Returns:
        The same *fields* mapping.
    """
    for key, entry in list(fields.items()):
        if not isinstance(entry, dict):
            fields[key] = {"candidate_idx": 0, "text": "", "confidence": 0.0, "normalized": None}
            continue

        try:
            entry["candidate_idx"] = int(entry.get("candidate_idx", 0))
        except (TypeError, ValueError, OverflowError):
            entry["candidate_idx"] = 0

        text = entry.get("text", "")
        entry["text"] = "" if text is None else str(text)

        entry["confidence"] = _coerce_confidence(entry.get("confidence"))

        normalized = entry.get("normalized")
        entry["normalized"] = None if normalized is None else str(normalized)
    return fields


def _retry_sleep(attempt: int) -> None:
    """Sleep with exponential backoff before the next retry, never for a negative time."""
    time.sleep(max(0.0, _RETRY_BASE_SLEEP * (2 ** attempt)))


def llm_ocr(bundle: Dict[str, Any]) -> Dict[str, Any]:
    """Run LLM-based OCR on a KYC *bundle* and return ``{"fields": {...}}``.

    The bundle is converted to OpenAI content items by
    ``kyc.KYCEngine.openai_content_items_from_bundle`` and sent through the
    Responses API. When the response carries no text, a Chat Completions
    fallback is tried. Every model from ``_normalize_model_candidates()`` is
    tried, and within each model every effort from ``_reasoning_chain()``:

    * Transient API errors (rate limit, timeout, connection, server error) and
      empty / non-JSON / structurally invalid output are retried up to
      ``_MAX_RETRIES`` extra times with exponential backoff. After that, the
      next reasoning effort is tried.
    * Errors recognised by ``_is_model_access_error`` move on to the next
      reasoning effort, and then to the next model.
    * Other errors whose message mentions "reasoning" move on to the next
      effort when one remains. Anything else aborts immediately.

    Returns:
        ``{"fields": fields}``, with each entry normalized by ``_normalize_fields``.

    Raises:
        RuntimeError: on a non-retryable failure, or once every model/effort
            combination is exhausted (chained to the last relevant error).
    """
    client = _get_client()
    content_items = kyc.KYCEngine.openai_content_items_from_bundle(bundle)
    # Only used for debug logging, but debug-log arguments are evaluated
    # eagerly, so a non-sized value must not abort the whole call.
    try:
        item_count: Optional[int] = len(content_items)
    except TypeError:
        item_count = None

    model_candidates = _normalize_model_candidates()
    reasoning_chain = _reasoning_chain()
    # A negative KYC_OCR_MAX_RETRIES would otherwise make zero attempts.
    max_retries = max(0, _MAX_RETRIES)
    last_model_error: Optional[Exception] = None
    last_retry_error: Optional[Exception] = None

    for model in model_candidates:
        for effort_idx, reasoning_effort in enumerate(reasoning_chain):
            has_next_effort = effort_idx < len(reasoning_chain) - 1
            for attempt in range(max_retries + 1):
                try:
                    req = _build_request_kwargs(model, content_items, reasoning_effort)
                    _debug_log(
                        "OCR request",
                        {
                            "model": model,
                            "attempt": attempt + 1,
                            "reasoning": reasoning_effort,
                            "items": item_count,
                        },
                    )
                    resp = client.responses.create(**req)
                    raw_text = _response_to_text(resp)
                    if not raw_text:
                        _debug_log("responses output empty, trying chat fallback", {"model": model})
                        raw_text = _call_chat_fallback(client, model, content_items)
                        if not raw_text:
                            _debug_log("chat fallback output empty", {"model": model})

                    data = _extract_json_strict(raw_text)
                    # A non-object JSON reply (e.g. an array) must take the
                    # retryable ValueError path, not fail on ``.get`` below.
                    if not isinstance(data, dict):
                        raise ValueError("Model output is not a JSON object")
                    fields = data.get("fields")
                    if not isinstance(fields, dict):
                        raise ValueError("Missing/invalid 'fields' in OCR response")

                    return {"fields": _normalize_fields(fields)}

                except (RateLimitError, APITimeoutError, APIConnectionError, InternalServerError) as e:
                    last_retry_error = e
                    _debug_log("retryable error:", type(e).__name__, str(e), "model=", model, "attempt=", attempt + 1)
                    if attempt < max_retries:
                        _retry_sleep(attempt)
                        continue
                    break

                except ValueError as e:
                    # Empty, non-JSON or structurally invalid output: retry as-is.
                    last_retry_error = e
                    if str(e).startswith("Empty model output"):
                        _debug_log("empty output, retrying", {"model": model, "reasoning": reasoning_effort, "attempt": attempt + 1})
                    if attempt < max_retries:
                        _retry_sleep(attempt)
                        continue
                    break

                except Exception as e:
                    if _is_model_access_error(e):
                        last_model_error = e
                        break
                    _debug_log("non-retryable OCR error:", type(e).__name__, str(e), "model=", model)
                    # Likely a reasoning-parameter incompatibility: try the next effort.
                    if has_next_effort and "reasoning" in str(e).lower():
                        break
                    raise RuntimeError(f"llm_ocr non-retryable failure: {e}") from e

    if last_model_error is not None:
        raise RuntimeError(f"llm_ocr model not available: {last_model_error}") from last_model_error
    if last_retry_error is not None:
        raise RuntimeError(f"llm_ocr retryable failure: unable to get successful response after fallback attempts: {last_retry_error}") from last_retry_error
    raise RuntimeError("llm_ocr failed without receiving a valid response")
