
# kyc_policy_engine.py (LLM OCR edition)
# ------------------------------------------------------------
# Trendo KYC: YOLO(OBB) -> crop fields -> ONE LLM request for OCR -> rule-based validation + (on/off) face match
#
# IMPORTANT:
# - This file intentionally REMOVES EasyOCR and all "manual OCR" logic.
# - Implement `llm_ocr(request_bundle) -> response_bundle` on your server and inject into run_kyc/run_kyc_auto_variant.
#
from __future__ import annotations

import os
import re
import math
import json
import base64
import threading
from dataclasses import dataclass, field
from typing import Dict, Any, List, Optional, Tuple, Iterable, Callable

import cv2
import numpy as np
import torch
from ultralytics import YOLO

# ============================================================
# 1) Light normalization (ONLY for server-side matching/validation)
#    - OCR is handled by LLM, but we keep robust string normalization
#      so matching does not break on minor Arabic/Persian variants.
# ============================================================

# Translation table: Persian (U+06F0..U+06F9) and Arabic-Indic (U+0660..U+0669)
# digits -> ASCII digits, applied before any numeric matching/validation.
_FA_AR_TO_EN_DIGITS = str.maketrans({
    "۰": "0", "۱": "1", "۲": "2", "۳": "3", "۴": "4",
    "۵": "5", "۶": "6", "۷": "7", "۸": "8", "۹": "9",
    "٠": "0", "١": "1", "٢": "2", "٣": "3", "٤": "4",
    "٥": "5", "٦": "6", "٧": "7", "٨": "8", "٩": "9",
})

# Folds common Arabic letter variants to their Persian counterparts so that
# server-side string matching is stable regardless of which script/keyboard
# produced the text. Standalone hamza and tatweel are removed entirely.
_AR_FA_LETTERS = str.maketrans({
    "ي": "ی",
    "ى": "ی",
    "ك": "ک",
    "ة": "ه",
    "ۀ": "ه",
    "ؤ": "و",
    "إ": "ا",
    "أ": "ا",
    "ٱ": "ا",
    "ء": "",
    "ئ": "ی",
    "ﻻ": "لا",
    "ـ": "",   # tatweel
})

# Arabic diacritic / annotation mark ranges stripped before comparison.
_DIACRITICS_RE = re.compile(r"[\u0610-\u061A\u064B-\u065F\u0670\u06D6-\u06ED]")
# Zero-width non-joiner; replaced with a plain space during normalization.
_ZWNJ = "\u200c"


def collapse_spaces(s: str) -> str:
    """Trim *s* and squeeze every internal whitespace run to a single space (None-safe)."""
    return " ".join((s or "").split())


def to_en_digits(s: str) -> str:
    """Replace Persian/Arabic-Indic digits in *s* with ASCII digits (None-safe)."""
    text = "" if not s else s
    return text.translate(_FA_AR_TO_EN_DIGITS)


def normalize_fa_text(s: str) -> str:
    """Canonicalize Persian text for matching: ASCII digits, Persian letter
    forms, ZWNJ turned into a space, diacritics removed, whitespace collapsed."""
    text = to_en_digits(s or "")
    text = text.translate(_AR_FA_LETTERS).replace(_ZWNJ, " ")
    return collapse_spaces(_DIACRITICS_RE.sub("", text))


def normalize_name(s: str) -> str:
    """Normalize a personal name for fuzzy comparison (lowercased, punctuation
    stripped except hyphen, Arabic-script ranges preserved)."""
    cleaned = normalize_fa_text(s)
    cleaned = re.sub(r"[^\w\u0600-\u06FF\u0750-\u077F\u08A0-\u08FF\s\-]", "", cleaned, flags=re.UNICODE)
    return collapse_spaces(cleaned).lower()


def normalize_digits_only(s: str) -> str:
    """Return only the ASCII digits of *s* (after Persian/Arabic digit mapping)."""
    mapped = to_en_digits(s)
    return "".join(ch for ch in mapped if "0" <= ch <= "9")


def normalize_id_alphanum(s: str) -> str:
    """Keep only ASCII letters/digits of *s* (after digit mapping), uppercased."""
    mapped = to_en_digits(s)
    kept = "".join(ch for ch in mapped if ch.isascii() and ch.isalnum())
    return kept.upper()


# ============================================================
# 2) Similarity + loose date parsing (for matching)
# ============================================================

def levenshtein_ratio(a: str, b: str) -> float:
    """Similarity in [0, 1]: 1 - edit_distance / max(len). None treated as empty."""
    a = a or ""
    b = b or ""
    if a == b:
        return 1.0
    if not a or not b:
        return 0.0

    # Rolling single-row dynamic programming over the edit-distance matrix.
    row = list(range(len(b) + 1))
    for i in range(1, len(a) + 1):
        nxt = [i]
        for j in range(1, len(b) + 1):
            substitution = row[j - 1] + (a[i - 1] != b[j - 1])
            nxt.append(min(row[j] + 1, nxt[-1] + 1, substitution))
        row = nxt

    return 1.0 - row[-1] / max(len(a), len(b))


def token_set_similarity(a: str, b: str) -> float:
    """Name similarity: 60% token Jaccard + 40% Levenshtein on normalized names."""
    norm_a = normalize_name(a)
    norm_b = normalize_name(b)
    if not norm_a or not norm_b:
        return 0.0
    tokens_a = set(norm_a.split())
    tokens_b = set(norm_b.split())
    union_size = len(tokens_a | tokens_b)
    jaccard = (len(tokens_a & tokens_b) / union_size) if union_size else 0.0
    return 0.6 * jaccard + 0.4 * levenshtein_ratio(norm_a, norm_b)


def parse_date_loose(s: str) -> Optional[Tuple[int, int, int]]:
    """
    Pull the first three numbers out of *s* and return the first plausible
    (year, month, day) assignment, or None. Separators -, . are treated as /.
    This is only for validation/matching, not for OCR.
    """
    cleaned = to_en_digits(s or "").replace("-", "/").replace(".", "/")
    pieces = re.findall(r"\d{1,4}", cleaned)
    if len(pieces) < 3:
        return None

    a, b, c = (int(p) for p in pieces[:3])
    # Same ordering as the original permutation list: first plausible wins.
    for y, m, d in ((a, b, c), (a, c, b), (b, a, c), (b, c, a), (c, a, b), (c, b, a)):
        if 1000 <= y <= 2200 and 1 <= m <= 12 and 1 <= d <= 31:
            return (y, m, d)
    return None


# ============================================================
# 3) Iran National Code checksum (validator only)
# ============================================================

def iran_national_code_is_valid(code: str) -> bool:
    """Check the 10-digit Iranian national code: length, not-all-same, mod-11 checksum."""
    code = normalize_digits_only(code)
    if len(code) != 10 or len(set(code)) == 1:
        return False

    digits = [int(ch) for ch in code]
    # Weighted sum of the first nine digits with weights 10 down to 2.
    weighted = sum(d * w for d, w in zip(digits[:9], range(10, 1, -1)))
    remainder = weighted % 11
    expected = remainder if remainder < 2 else 11 - remainder
    return digits[9] == expected


# ============================================================
# 4) OBB crop utilities
# ============================================================

def order_points_clockwise(pts: np.ndarray) -> np.ndarray:
    """Order 4 quad corners as [top-left, top-right, bottom-right, bottom-left].

    Uses the classic sum/difference heuristic: tl minimizes x+y, br maximizes
    it; tr minimizes y-x, bl maximizes it.
    """
    p = pts.astype(np.float32)
    sums = p[:, 0] + p[:, 1]
    diffs = p[:, 1] - p[:, 0]
    return np.array(
        [p[np.argmin(sums)], p[np.argmin(diffs)], p[np.argmax(sums)], p[np.argmax(diffs)]],
        dtype=np.float32,
    )


def clip_points(pts: np.ndarray, w: int, h: int) -> np.ndarray:
    """Clamp point coordinates into the image bounds [0, w-1] x [0, h-1].

    Fix: operate on a copy instead of mutating the caller's array in place.
    The previous version wrote into *pts* directly, silently changing the
    caller's data; all call sites in this file consume the return value, so
    returning a fresh array is backward-compatible.
    """
    pts = np.asarray(pts).copy()
    pts[:, 0] = np.clip(pts[:, 0], 0, w - 1)
    pts[:, 1] = np.clip(pts[:, 1], 0, h - 1)
    return pts


def warp_quad_to_rect(img_bgr: np.ndarray, quad: np.ndarray) -> np.ndarray:
    """Perspective-rectify the quadrilateral *quad* out of *img_bgr*.

    The output size is derived from the quad's opposing edge lengths and is
    at least 2x2 so OpenCV never receives a degenerate target.
    """
    img_h, img_w = img_bgr.shape[:2]
    src = order_points_clockwise(clip_points(quad.astype(np.float32), img_w, img_h))

    tl, tr, br, bl = src
    out_w = max(2, int(max(np.linalg.norm(br - bl), np.linalg.norm(tr - tl))))
    out_h = max(2, int(max(np.linalg.norm(tr - br), np.linalg.norm(tl - bl))))

    dst = np.array(
        [[0, 0], [out_w - 1, 0], [out_w - 1, out_h - 1], [0, out_h - 1]],
        dtype=np.float32,
    )
    transform = cv2.getPerspectiveTransform(src, dst)
    return cv2.warpPerspective(img_bgr, transform, (out_w, out_h), flags=cv2.INTER_CUBIC)


def _resize_max_dim(img: np.ndarray, max_dim: int = 900) -> np.ndarray:
    h, w = img.shape[:2]
    m = max(h, w)
    if m <= max_dim:
        return img
    scale = max_dim / float(m)
    nw = max(2, int(round(w * scale)))
    nh = max(2, int(round(h * scale)))
    return cv2.resize(img, (nw, nh), interpolation=cv2.INTER_AREA)


def _encode_jpeg_data_url(img_bgr: np.ndarray, jpeg_quality: int = 85, max_dim: int = 900) -> str:
    """
    Encode an image as a compact `data:image/jpeg;base64,...` URL for the LLM.

    The image is first capped to *max_dim* on its longest side, then JPEG
    encoded at *jpeg_quality*. Raises RuntimeError if encoding fails.
    """
    small = _resize_max_dim(img_bgr, max_dim=max_dim)
    ok, encoded = cv2.imencode(".jpg", small, [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_quality)])
    if not ok:
        raise RuntimeError("cv2.imencode(.jpg) failed")
    payload = base64.b64encode(encoded.tobytes()).decode("ascii")
    return "data:image/jpeg;base64," + payload


# ============================================================
# 5) Face matcher (ON/OFF) - InsightFace
# ============================================================

class FaceMatcherBase:
    """Interface for comparing a document photo crop against a selfie."""

    def match(self, doc_photo_bgr: np.ndarray, selfie_bgr: np.ndarray) -> Dict[str, Any]:
        """Return a match pack: {"score01", "cosine", "reason", "details"}."""
        raise NotImplementedError


class FaceMatcherUnavailable(FaceMatcherBase):
    """Null-object matcher used when InsightFace could not be initialized."""

    def match(self, doc_photo_bgr: np.ndarray, selfie_bgr: np.ndarray) -> Dict[str, Any]:
        """Always report that the face module is unavailable."""
        pack: Dict[str, Any] = {
            "score01": None,
            "cosine": None,
            "reason": "face_module_unavailable",
            "details": None,
        }
        return pack


class InsightFaceMatcher(FaceMatcherBase):
    """Face matcher backed by InsightFace's `buffalo_l` analysis pipeline."""

    def __init__(self, det_size: Tuple[int, int] = (640, 640), prefer_gpu: bool = False):
        from insightface.app import FaceAnalysis  # type: ignore

        if prefer_gpu:
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]
        self.app = FaceAnalysis(name="buffalo_l", providers=providers)
        self.app.prepare(ctx_id=-1, det_size=det_size)

    @staticmethod
    def _largest_face(faces):
        """Pick the face with the largest bounding-box area; None if empty."""
        if not faces:
            return None

        def bbox_area(f):
            x1, y1, x2, y2 = f.bbox
            return float((x2 - x1) * (y2 - y1))

        return max(faces, key=bbox_area)

    @staticmethod
    def _cosine(a: np.ndarray, b: np.ndarray) -> float:
        """Cosine similarity with epsilon-guarded norms (safe for zero vectors)."""
        va = a.astype(np.float32)
        vb = b.astype(np.float32)
        denom = (np.linalg.norm(va) + 1e-9) * (np.linalg.norm(vb) + 1e-9)
        return float(np.dot(va, vb) / denom)

    def match(self, doc_photo_bgr: np.ndarray, selfie_bgr: np.ndarray) -> Dict[str, Any]:
        """Compare the largest detected face of each image; cosine mapped to [0, 1]."""
        doc_faces = self.app.get(doc_photo_bgr)
        sel_faces = self.app.get(selfie_bgr)

        doc_face = self._largest_face(doc_faces)
        sel_face = self._largest_face(sel_faces)

        if doc_face is None:
            return {"score01": None, "cosine": None, "reason": "no_face_in_document_photo", "details": {"doc_faces": len(doc_faces)}}
        if sel_face is None:
            return {"score01": None, "cosine": None, "reason": "no_face_in_selfie", "details": {"selfie_faces": len(sel_faces)}}

        cos = self._cosine(doc_face.normed_embedding, sel_face.normed_embedding)
        return {"score01": float((cos + 1.0) / 2.0), "cosine": float(cos), "reason": "ok", "details": None}


# Lazily-created process-wide face matcher singleton (see get_face_matcher).
_FACE_MATCHER: Optional[FaceMatcherBase] = None
# Guards creation/reads of _FACE_MATCHER across threads.
_FACE_LOCK = threading.Lock()


def get_face_matcher() -> FaceMatcherBase:
    """Lazily build and cache the process-wide face matcher.

    Falls back to FaceMatcherUnavailable when InsightFace cannot be loaded;
    the fallback is cached too, so a broken install is only probed once.
    """
    global _FACE_MATCHER
    with _FACE_LOCK:
        if _FACE_MATCHER is None:
            try:
                _FACE_MATCHER = InsightFaceMatcher(prefer_gpu=False)
            except Exception:
                _FACE_MATCHER = FaceMatcherUnavailable()
        return _FACE_MATCHER


def compute_face_pack(
    *,
    doc_image_path: str,
    selfie_image_path: str,
    doc_photo_quad: Optional[List[List[float]]],
) -> Dict[str, Any]:
    """
    Compute the face-match pack once; the caller's policy decides whether to use it.

    Returns {"score01", "cosine", "reason", "details"} with None scores and a
    diagnostic reason whenever a precondition (selfie, detected photo quad,
    readable images) is missing.
    """
    def unavailable(reason: str) -> Dict[str, Any]:
        return {"score01": None, "cosine": None, "reason": reason, "details": None}

    if not selfie_image_path:
        return unavailable("selfie_missing")
    if not doc_photo_quad:
        return unavailable("doc_photo_not_detected")

    doc_img = cv2.imread(doc_image_path)
    selfie_img = cv2.imread(selfie_image_path)
    if doc_img is None:
        return unavailable("cannot_read_doc_image")
    if selfie_img is None:
        return unavailable("cannot_read_selfie_image")

    photo_crop = warp_quad_to_rect(doc_img, np.array(doc_photo_quad, dtype=np.float32))
    return get_face_matcher().match(photo_crop, selfie_img)


# ============================================================
# 6) Config schema (server-driven)
# ============================================================

@dataclass
class FieldRule:
    """Per-field extraction/validation rule, built from a server payload
    (see doc_config_from_payload). Attribute names mirror config keys."""

    class_names: List[str]  # YOLO class names whose detections may carry this field
    key: str                # logical field key; also the user_input lookup key
    required: bool          # counts toward required-field coverage in scoring

    min_det_conf: float = 0.25  # minimum YOLO confidence to keep a detection

    # NOTE: In LLM-OCR edition, this is the minimum confidence you expect
    # the LLM to report for the extracted value (if you use it).
    min_ocr_conf: float = 0.20

    ocr_kind: str = "text"  # none|text|numeric|date|mrz

    match_type: str = "optional"  # optional|exact|fuzzy|date
    match_threshold: float = 0.85  # similarity cutoff for fuzzy matching

    expected_len: int = 0  # expected normalized length; 0 means unspecified

    validator: Optional[str] = None  # named validator, e.g. "iran_national_code"
    constraints: Dict[str, Any] = field(default_factory=dict)  # length/prefix/regex/... checks

    # NOTE(review): weight/must_match/match_gate are consumed by the scoring
    # policy (KYCEngine.score) — exact gating semantics live there.
    weight: float = 1.0
    must_match: bool = False
    match_gate: bool = False

    max_candidates: int = 1  # crops offered to the LLM per field (capped at 6 downstream)

    input_aliases: Optional[List[str]] = None  # alternate user_input keys for this field


@dataclass
class DocConfig:
    """Document-level policy: field rules plus approve/review/reject thresholds.

    Threshold semantics are consumed by KYCEngine.score(); values here are
    the server-driven defaults.
    """

    doc_id: str              # document type identifier
    rules: List[FieldRule]   # per-field extraction/validation rules

    min_detected_fields_count: Optional[int] = None  # optional absolute floor on detections

    approve_min_coverage: float = 1.0      # required-field coverage needed to approve
    approve_min_extraction: float = 0.78   # minimum extraction quality to approve
    approve_min_match_core: float = 0.92   # minimum match score over core fields
    approve_min_match_all: float = 0.0  # analytics only

    review_min_coverage: float = 0.75      # at/above this (but below approve) -> review
    reject_below_coverage: float = 0.50    # below this coverage -> reject

    approve_no_input_extra_buffer: float = 0.10  # extra margin applied when no user input exists

    # --- Face policy: ON/OFF (no optional) ---
    require_face_match: bool = False
    face_metric: str = "score01"   # score01|cosine
    face_match_threshold: float = 0.75

    # name swap is still useful for layout variants
    swap_pairs: List[Tuple[str, str]] = field(default_factory=list)
    enable_name_swap: bool = True
    name_swap_margin: float = 0.06


def doc_config_from_payload(payload: Dict[str, Any]) -> DocConfig:
    """
    Build a DocConfig from a server-supplied payload dict.

    Tolerated payload variants:
      - legacy "class_name" (single string) is upgraded to "class_names" (list),
      - "constraints" may arrive as a dict, a JSON-encoded string, or be absent,
      - unknown rule keys are dropped (fix: previously any extra key sent by a
        newer server raised TypeError at FieldRule(**rr)).

    Args:
        payload: parsed JSON config from the server.

    Returns:
        A fully-populated DocConfig with defaults applied for missing keys.
    """
    # Only keys FieldRule actually declares may be passed to its constructor.
    known_rule_keys = set(FieldRule.__dataclass_fields__)

    rules: List[FieldRule] = []
    for r in payload.get("rules", []):
        rr = dict(r)

        # Accept either "class_names": [...] or legacy "class_name": "x".
        cn = rr.get("class_names") or rr.get("class_name")
        if isinstance(cn, str):
            rr["class_names"] = [cn]
        else:
            rr["class_names"] = list(cn or [])
        rr.pop("class_name", None)

        # Constraints may be a JSON string; fall back to {} on any problem.
        cst = rr.get("constraints")
        if cst is None:
            rr["constraints"] = {}
        elif isinstance(cst, str):
            try:
                rr["constraints"] = json.loads(cst)
            except Exception:
                rr["constraints"] = {}
        elif not isinstance(cst, dict):
            rr["constraints"] = {}

        # Robustness: drop undeclared keys so newer server payloads cannot
        # crash rule construction.
        rr = {k: v for k, v in rr.items() if k in known_rule_keys}

        rules.append(FieldRule(**rr))

    swap_pairs = payload.get("swap_pairs") or []
    swap_pairs = [tuple(x) for x in swap_pairs if isinstance(x, (list, tuple)) and len(x) == 2]

    return DocConfig(
        doc_id=str(payload.get("doc_id") or payload.get("id") or "doc"),
        rules=rules,
        min_detected_fields_count=payload.get("min_detected_fields_count", None),

        approve_min_coverage=float(payload.get("approve_min_coverage", 1.0)),
        approve_min_extraction=float(payload.get("approve_min_extraction", 0.78)),
        approve_min_match_core=float(payload.get("approve_min_match_core", payload.get("approve_min_match", 0.92))),
        approve_min_match_all=float(payload.get("approve_min_match_all", 0.0)),

        review_min_coverage=float(payload.get("review_min_coverage", 0.75)),
        reject_below_coverage=float(payload.get("reject_below_coverage", 0.50)),
        approve_no_input_extra_buffer=float(payload.get("approve_no_input_extra_buffer", 0.10)),

        require_face_match=bool(payload.get("require_face_match", False)),
        face_metric=str(payload.get("face_metric", "score01")),
        face_match_threshold=float(payload.get("face_match_threshold", 0.75)),

        swap_pairs=swap_pairs,
        enable_name_swap=bool(payload.get("enable_name_swap", True)),
        name_swap_margin=float(payload.get("name_swap_margin", 0.06)),
    )


# ============================================================
# 7) Engine caching (YOLO only)
# ============================================================

# One KYCEngine per YOLO model path, shared across requests/threads.
_ENGINE_CACHE: Dict[str, "KYCEngine"] = {}
# Guards _ENGINE_CACHE creation/lookup.
_ENGINE_LOCK = threading.Lock()


def get_engine(model_path: str) -> "KYCEngine":
    """Return the cached KYCEngine for *model_path*, creating it on first use."""
    key = str(model_path)
    with _ENGINE_LOCK:
        engine = _ENGINE_CACHE.get(key)
        if engine is None:
            engine = KYCEngine(model_path)
            _ENGINE_CACHE[key] = engine
        return engine


class KYCEngine:
    def __init__(self, yolo_model_path: str):
        self.model = YOLO(yolo_model_path)

        names = self.model.names
        if isinstance(names, dict):
            self.class_names = {int(k): str(v) for k, v in names.items()}
        else:
            self.class_names = {i: str(n) for i, n in enumerate(names)}

        self._infer_lock = threading.Lock()

    @staticmethod
    def _clamp01(x: float) -> float:
        return max(0.0, min(1.0, float(x)))

    def _get_expected(self, rule: FieldRule, user_input: Dict[str, str]) -> Optional[str]:
        v = user_input.get(rule.key)
        if v:
            return v
        if rule.input_aliases:
            for ak in rule.input_aliases:
                vv = user_input.get(ak)
                if vv:
                    return vv
        return None

    def detect(self, image_path: str, conf: float, iou: float, device: str, max_det: int) -> List[Dict[str, Any]]:
        with self._infer_lock:
            preds = self.model.predict(
                source=image_path, conf=conf, iou=iou, device=device, max_det=max_det, verbose=False
            )

        if not preds:
            return []

        r = preds[0]
        dets: List[Dict[str, Any]] = []

        # OBB path
        if getattr(r, "obb", None) is not None and r.obb is not None:
            quads = r.obb.xyxyxyxy
            cls = r.obb.cls
            cf = r.obb.conf

            if isinstance(quads, torch.Tensor):
                quads = quads.detach().cpu().numpy()
            if isinstance(cls, torch.Tensor):
                cls = cls.detach().cpu().numpy()
            if isinstance(cf, torch.Tensor):
                cf = cf.detach().cpu().numpy()

            for i in range(len(quads)):
                cid = int(cls[i])
                dets.append({
                    "class_id": cid,
                    "class_name": self.class_names.get(cid, str(cid)),
                    "conf": float(cf[i]),
                    "quad": quads[i].astype(float).tolist(),
                })
            return dets

        # XYXY fallback
        if getattr(r, "boxes", None) is not None and r.boxes is not None:
            xyxy = r.boxes.xyxy
            cls = r.boxes.cls
            cf = r.boxes.conf

            if isinstance(xyxy, torch.Tensor):
                xyxy = xyxy.detach().cpu().numpy()
            if isinstance(cls, torch.Tensor):
                cls = cls.detach().cpu().numpy()
            if isinstance(cf, torch.Tensor):
                cf = cf.detach().cpu().numpy()

            for i in range(len(xyxy)):
                x1, y1, x2, y2 = [float(x) for x in xyxy[i]]
                cid = int(cls[i])
                quad = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
                dets.append({
                    "class_id": cid,
                    "class_name": self.class_names.get(cid, str(cid)),
                    "conf": float(cf[i]),
                    "quad": quad,
                })

        return dets

    # ---------------- LLM OCR request building ----------------

    @staticmethod
    def build_llm_ocr_prompt_bundle(
        *,
        fields_plan: List[Dict[str, Any]],
        locale_hint: str = "fa",
    ) -> Dict[str, Any]:
        """
        Returns a single, self-contained bundle to be sent to your LLM layer.
        You can convert it to OpenAI Responses API (or any other provider) on your server.

        fields_plan: list of:
          {
            "key": str,
            "kind": "text|numeric|date|mrz",
            "candidates": [{"idx": int, "det_conf": float, "image_data_url": str}, ...]
          }

        Expected LLM response:
          {
            "fields": {
              "<key>": {
                 "candidate_idx": 0,
                 "text": "...",          # what you read
                 "confidence": 0.0-1.0,  # optional but recommended
                 "normalized": "..."     # optional; server can re-normalize anyway
              }, ...
            }
          }
        """
        schema = {
            "schema_version": 1,
            "task": "kyc_field_ocr",
            "locale_hint": locale_hint,
            "fields": fields_plan,
            "output_contract": {
                "type": "json",
                "top_level_keys": ["fields"],
                "field_object": {
                    "candidate_idx": "int (required)",
                    "text": "string|null (required)",
                    "confidence": "float 0..1 (optional)",
                    "normalized": "string|null (optional)"
                }
            }
        }
        return schema

    @staticmethod
    def openai_content_items_from_bundle(bundle: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Helper: convert bundle into OpenAI Responses API `input` content items in ONE request.
        (You can do the same on your server; provided for clarity.)

        Output is a list of items like:
          {"type":"input_text","text":"..."}
          {"type":"input_image","image_url":"data:image/jpeg;base64,..."}
        """
        fields = bundle.get("fields") or []
        # System-style instruction inside the user prompt (keeps it portable).
        intro = (
            "You are an OCR engine for KYC field crops.\n"
            "Read the text in each image crop and return ONLY valid JSON.\n"
            "Rules:\n"
            "- Do NOT hallucinate. If unreadable: text=null and confidence<=0.3.\n"
            "- For numeric fields: output ASCII digits only (0-9), no spaces.\n"
            "- For dates: output as seen; if you can normalize to YYYY-MM-DD, also provide normalized.\n"
            "- For Persian/Arabic text: keep original letters; fix obvious Arabic->Persian variants if confident.\n"
            "- Choose candidate_idx that best matches the intended field.\n"
            "Return JSON with structure:\n"
            "{ \"fields\": { \"<key>\": {\"candidate_idx\":0, \"text\":..., \"confidence\":0.0, \"normalized\":...}, ... } }\n"
        )

        items: List[Dict[str, Any]] = [{"type": "input_text", "text": intro}]
        # Keep a deterministic ordering for mapping images.
        for f in fields:
            key = f.get("key")
            kind = f.get("kind")
            cands = f.get("candidates") or []
            items.append({"type": "input_text", "text": f"\nFIELD key={key} kind={kind} candidates={len(cands)}"})
            for c in cands:
                idx = c.get("idx")
                det_conf = c.get("det_conf")
                items.append({"type": "input_text", "text": f"Candidate idx={idx} det_conf={det_conf}"})
                items.append({"type": "input_image", "image_url": c.get("image_data_url")})
        items.append({"type": "input_text", "text": "\nReturn ONLY JSON. No markdown."})
        return items

    def extract_with_llm(
        self,
        doc_image_path: str,
        doc_cfg: DocConfig,
        dets: List[Dict[str, Any]],
        llm_ocr: Callable[[Dict[str, Any]], Dict[str, Any]],
        user_input: Optional[Dict[str, str]] = None,
        debug: bool = False,
        jpeg_quality: int = 92,
        max_dim: int = 1200,
    ) -> Dict[str, Any]:
        """
        Crop -> single-LLM-call -> map-back extraction pipeline.

        1) Uses YOLO detections + rules to crop candidates.
        2) Builds ONE bundle and calls llm_ocr(bundle).
        3) Maps response back to per-field extraction pack (same shape as before).

        Args:
            doc_image_path: path to the full document image (read via cv2).
            doc_cfg: server-driven config; its rules drive cropping and OCR kinds.
            dets: detections as produced by detect().
            llm_ocr: callable taking the prompt bundle, returning the LLM's JSON dict.
            user_input: user-declared values; currently unused in this method
                beyond defaulting to {} (kept for signature parity).
            debug: attach detections, the bundle and the raw LLM response to the output.
            jpeg_quality, max_dim: JPEG parameters for the candidate crop encoding.

        Returns:
            {"fields": {key: pack}, "internals": {"quads", "doc_image_shape"}}
            where pack holds class_name/det_conf/ocr_conf/value_raw/value_norm/
            ocr_method/quad/candidate_idx.

        Raises:
            RuntimeError: if the document image cannot be read.
        """
        user_input = user_input or {}

        img = cv2.imread(doc_image_path)
        if img is None:
            raise RuntimeError(f"Cannot read image: {doc_image_path}")

        # Group detections by class name, sorted best-confidence first.
        by_class: Dict[str, List[Dict[str, Any]]] = {}
        for d in dets:
            by_class.setdefault(d["class_name"], []).append(d)
        for k in by_class:
            by_class[k].sort(key=lambda x: -float(x.get("conf", 0.0)))

        # Keep candidate structures so we can map LLM's candidate_idx back to quad.
        candidates_for_key: Dict[str, List[Dict[str, Any]]] = {}

        def top_k(rule: FieldRule) -> int:
            # Number of candidate crops offered to the LLM for this rule (1..6).
            k = int(rule.max_candidates or 1)
            if k <= 0:
                k = 1
            return min(k, 6)

        # ---- Build fields_plan for LLM (only OCR kinds) ----
        fields_plan: List[Dict[str, Any]] = []
        internals: Dict[str, Any] = {"quads": {}, "doc_image_shape": list(img.shape[:2])}
        fields: Dict[str, Any] = {}

        for rule in doc_cfg.rules:
            # collect detections for any matching class_name
            candidates: List[Dict[str, Any]] = []
            for cn in rule.class_names:
                candidates.extend(by_class.get(cn, []))
            if not candidates:
                continue

            # Filter by detection confidence, then keep the strongest top-k.
            candidates = [c for c in candidates if float(c.get("conf", 0.0)) >= float(rule.min_det_conf)]
            candidates.sort(key=lambda x: -float(x.get("conf", 0.0)))
            candidates = candidates[:top_k(rule)]
            if not candidates:
                continue

            kind = (rule.ocr_kind or "text").lower()

            # For "none" we don't send to LLM
            if kind == "none":
                # Record the best detection only (e.g. the photo region).
                best = candidates[0]
                fields[rule.key] = {
                    "class_name": rule.class_names[0],
                    "det_conf": float(best["conf"]),
                    "ocr_conf": None,
                    "value_raw": None,
                    "value_norm": None,
                    "ocr_method": None,
                    "quad": best["quad"],
                    "candidate_idx": 0,
                }
                internals["quads"][rule.key] = best["quad"]
                continue

            # OCR fields: crop + encode candidates
            plan_entry = {"key": rule.key, "kind": kind, "candidates": []}
            candidates_for_key[rule.key] = []

            for idx, det in enumerate(candidates):
                quad = np.array(det["quad"], dtype=np.float32)
                crop = warp_quad_to_rect(img, quad)
                data_url = _encode_jpeg_data_url(crop, jpeg_quality=jpeg_quality, max_dim=max_dim)
                plan_entry["candidates"].append({
                    "idx": idx,
                    "det_conf": float(det["conf"]),
                    "image_data_url": data_url,
                })
                candidates_for_key[rule.key].append(det)

            fields_plan.append(plan_entry)

        # ---- ONE request to LLM ----
        bundle = self.build_llm_ocr_prompt_bundle(fields_plan=fields_plan, locale_hint="fa")
        llm_resp = llm_ocr(bundle) or {}
        # Defensive: tolerate non-dict or malformed responses from the LLM layer.
        llm_fields = llm_resp.get("fields") if isinstance(llm_resp, dict) else None
        if not isinstance(llm_fields, dict):
            llm_fields = {}

        # ---- Map results back into engine output ----
        def normalize_value(rule: FieldRule, kind: str, value: str) -> str:
            # Server-side canonicalization of the LLM's reading, by kind/key.
            if kind == "numeric":
                return normalize_digits_only(value)
            if kind == "date":
                return collapse_spaces(to_en_digits(value))
            if rule.key in ("first_name", "last_name", "full_name"):
                return normalize_name(value)
            if rule.key in ("id_number", "passport_no"):
                return normalize_digits_only(value) if kind == "numeric" else normalize_id_alphanum(value)
            return normalize_fa_text(value)

        def heuristic_conf(kind: str, text: Optional[str], expected_len: int = 0) -> float:
            # Fallback confidence estimate when the LLM did not report one.
            t = (text or "").strip()
            if not t:
                return 0.0
            if kind == "numeric":
                d = normalize_digits_only(t)
                if expected_len and len(d) == expected_len:
                    return 0.9
                return 0.65
            if kind == "date":
                if parse_date_loose(t):
                    return 0.85
                return 0.55
            if kind == "mrz":
                # MRZ lines are recognizable by '<' fillers across >= 2 lines.
                if "<" in t and len(t.splitlines()) >= 2:
                    return 0.85
                return 0.55
            return 0.70

        for rule in doc_cfg.rules:
            kind = (rule.ocr_kind or "text").lower()
            if kind == "none":
                continue
            # was it planned?
            if rule.key not in candidates_for_key:
                continue

            resp = llm_fields.get(rule.key) or {}
            if not isinstance(resp, dict):
                resp = {}

            # Clamp the LLM's candidate index to the planned candidate list.
            cand_idx = resp.get("candidate_idx")
            try:
                cand_idx = int(cand_idx)
            except Exception:
                cand_idx = 0

            det_list = candidates_for_key.get(rule.key) or []
            if not det_list:
                continue
            if cand_idx < 0 or cand_idx >= len(det_list):
                cand_idx = 0

            det = det_list[cand_idx]
            det_conf = float(det.get("conf", 0.0))

            # Accept either "text" or legacy "value" from the LLM response.
            raw_text = resp.get("text")
            if raw_text is None:
                raw_text = resp.get("value")
            if raw_text is None:
                raw_text = ""
            raw_text = str(raw_text)

            # Use the LLM-reported confidence when present and numeric;
            # otherwise fall back to the heuristic estimate.
            conf = resp.get("confidence")
            if conf is None:
                conf = heuristic_conf(kind, raw_text, expected_len=int(rule.expected_len or rule.constraints.get("length") or 0))
            try:
                conf = float(conf)
            except Exception:
                conf = heuristic_conf(kind, raw_text, expected_len=int(rule.expected_len or rule.constraints.get("length") or 0))
            conf = self._clamp01(conf)

            # Always re-normalize server-side, even when the LLM provided one.
            norm_text = resp.get("normalized")
            if norm_text is None or str(norm_text).strip() == "":
                norm_text = normalize_value(rule, kind, raw_text)
            else:
                norm_text = normalize_value(rule, kind, str(norm_text))

            fields[rule.key] = {
                "class_name": rule.class_names[0],
                "det_conf": det_conf,
                "ocr_conf": conf,
                "value_raw": raw_text.strip(),
                "value_norm": norm_text,
                "ocr_method": "llm",
                "quad": det["quad"],
                "candidate_idx": cand_idx,
            }
            internals["quads"][rule.key] = det["quad"]

        out = {"fields": fields, "internals": internals}
        if debug:
            out["detections"] = dets
            out["llm_bundle"] = bundle
            out["llm_response_raw"] = llm_resp
        return out

    # ---------------- Scoring ----------------

    def score(
        self,
        doc_cfg: DocConfig,
        extracted_fields: Dict[str, Any],
        user_input: Optional[Dict[str, str]],
        face_pack: Optional[Dict[str, Any]],
        debug: bool
    ) -> Dict[str, Any]:
        user_input = user_input or {}

        required_rules = [r for r in doc_cfg.rules if r.required]

        found_required_det = 0
        det_ok_keys: set = set()

        extraction_pairs: List[Tuple[float, float]] = []
        match_pairs_all: List[Tuple[float, float]] = []
        match_pairs_core: List[Tuple[float, float]] = []

        per_field: Dict[str, Any] = {}
        mismatch_flags: List[str] = []
        invalid_flags: List[str] = []

        def validate_constraints(rule: FieldRule, kind: str, raw: Optional[str], norm: Optional[str]) -> Tuple[bool, Optional[str]]:
            """Check one field value against rule.constraints.

            Returns (True, None) on success, otherwise (False, reason_tag).
            Malformed constraint values (non-int bounds, invalid regex) are
            skipped rather than treated as failures.
            """
            constraints = rule.constraints or {}
            if not constraints:
                return True, None

            raw_s = raw or ""
            norm_s = norm or ""

            # The three length constraints share one shape: compare len(norm_s)
            # against an int bound pulled from the config.
            length_checks = (
                ("length", lambda n, bound: n != bound, "length_expected_{}"),
                ("min_length", lambda n, bound: n < bound, "min_length_{}"),
                ("max_length", lambda n, bound: n > bound, "max_length_{}"),
            )
            for key, violates, tag in length_checks:
                value = constraints.get(key)
                if value is None:
                    continue
                try:
                    bound = int(value)
                except Exception:
                    continue  # non-numeric bound in config: ignore this check
                if violates(len(norm_s), bound):
                    return False, tag.format(bound)

            prefix = constraints.get("prefix")
            if prefix:
                pref = str(prefix)
                if not str(norm_s).startswith(pref):
                    return False, f"prefix_{pref}"

            pattern = constraints.get("regex")
            if pattern:
                try:
                    if re.fullmatch(str(pattern), str(norm_s)) is None:
                        return False, "regex_mismatch"
                except re.error:
                    pass  # invalid pattern in config: treat as no constraint

            # Dates are validated on the RAW value, since normalization may
            # strip separators the parser relies on.
            if kind == "date" and constraints.get("must_parse", False):
                if parse_date_loose(raw_s) is None:
                    return False, "date_unparsed"

            return True, None

        def wavg(pairs: List[Tuple[float, float]]) -> float:
            """Weighted average of (value, weight) pairs; 0.0 when empty or
            when the total weight is not positive."""
            total_weight = sum(weight for _, weight in pairs)
            if total_weight <= 0:
                return 0.0
            return float(sum(value * weight for value, weight in pairs) / total_weight)

        for rule in doc_cfg.rules:
            f = extracted_fields.get(rule.key)
            if not f:
                per_field[rule.key] = {
                    "present": False, "det_ok": False, "ocr_ok": False, "extraction": 0.0,
                    "match": None, "reason": "missing"
                }
                continue

            det_conf = self._clamp01(float(f.get("det_conf", 0.0)))
            det_ok = det_conf >= float(rule.min_det_conf)
            if det_ok:
                det_ok_keys.add(rule.key)

            kind = (rule.ocr_kind or "text").lower()

            if kind == "none":
                ocr_ok = True
                extraction = math.sqrt(det_conf * 1.0)
                ocr_conf = None
                raw_value = None
                norm_value = None
            else:
                ocr_conf = self._clamp01(float(f.get("ocr_conf", 0.0) or 0.0))
                raw_value = f.get("value_raw") or ""
                norm_value = f.get("value_norm") or ""
                ocr_ok = (ocr_conf >= float(rule.min_ocr_conf)) and bool(str(norm_value).strip())
                extraction = math.sqrt(det_conf * max(0.01, ocr_conf))

            if rule.required and det_ok:
                found_required_det += 1

            valid_ok = True
            valid_reason = None
            if rule.validator == "iran_national_code" and kind != "none":
                if not iran_national_code_is_valid(norm_value or ""):
                    valid_ok = False
                    valid_reason = "iran_national_code_invalid"

            c_ok, c_reason = validate_constraints(rule, kind, raw_value, norm_value)
            if not c_ok:
                valid_ok = False
                valid_reason = valid_reason or c_reason

            if not valid_ok:
                invalid_flags.append(f"invalid:{rule.key}:{valid_reason}")

            extraction_pairs.append((float(extraction), float(rule.weight)))

            expected = self._get_expected(rule, user_input)
            mscore: Optional[float] = None
            mreason = "no_user_input"

            if expected and kind != "none":
                mt = (rule.match_type or "optional").lower()

                if mt == "exact":
                    got_norm = norm_value or ""
                    exp_norm = expected

                    if kind == "numeric":
                        got_norm = normalize_digits_only(got_norm)
                        exp_norm = normalize_digits_only(exp_norm)
                    elif rule.key in ("first_name", "last_name", "full_name"):
                        got_norm = normalize_name(got_norm)
                        exp_norm = normalize_name(exp_norm)
                    else:
                        got_norm = normalize_fa_text(got_norm)
                        exp_norm = normalize_fa_text(exp_norm)

                    mscore = 1.0 if got_norm == exp_norm else 0.0
                    mreason = "exact_match" if mscore == 1.0 else "exact_mismatch"

                elif mt == "fuzzy":
                    mscore = float(token_set_similarity(raw_value or "", expected))
                    mreason = "fuzzy" if mscore >= float(rule.match_threshold) else "fuzzy_below_threshold"

                elif mt == "date":
                    got = parse_date_loose(raw_value or "")
                    exp = parse_date_loose(expected)
                    if got and exp:
                        mscore = 1.0 if got == exp else 0.0
                        mreason = "date_match" if mscore == 1.0 else "date_mismatch"
                    else:
                        mscore = None
                        mreason = "date_unparsed"

            if mscore is not None:
                match_pairs_all.append((float(mscore), float(rule.weight)))
                if rule.match_gate:
                    match_pairs_core.append((float(mscore), float(rule.weight)))

            if rule.must_match:
                if mscore == 0.0 and mreason in ("exact_mismatch", "date_mismatch"):
                    mismatch_flags.append(f"mismatch:{rule.key}")
                if mreason == "fuzzy_below_threshold":
                    mismatch_flags.append(f"mismatch:{rule.key}")

            rec: Dict[str, Any] = {
                "present": det_ok,
                "det_ok": det_ok,
                "ocr_ok": ocr_ok,
                "det_conf": det_conf,
                "ocr_conf": ocr_conf,
                "extraction": float(extraction),
                "match": mscore,
                "reason": mreason,
                "valid": valid_ok,
                "valid_reason": valid_reason,
            }
            if debug:
                rec.update({
                    "raw": raw_value,
                    "normalized": norm_value,
                    "class_name": f.get("class_name"),
                    "ocr_method": f.get("ocr_method"),
                    "quad": f.get("quad"),
                    "candidate_idx": f.get("candidate_idx"),
                })
            per_field[rule.key] = rec

        coverage = (found_required_det / max(1, len(required_rules))) if required_rules else 1.0
        extraction_score = wavg(extraction_pairs)
        match_all = (wavg(match_pairs_all) if match_pairs_all else None)
        match_core = (wavg(match_pairs_core) if match_pairs_core else None)

        count_gate_ok = True
        if doc_cfg.min_detected_fields_count is not None:
            count_gate_ok = (len(det_ok_keys) >= int(doc_cfg.min_detected_fields_count))

        face_val = None
        face_reason = None
        if face_pack is not None:
            face_reason = face_pack.get("reason")
            if doc_cfg.face_metric == "cosine":
                face_val = face_pack.get("cosine")
            else:
                face_val = face_pack.get("score01")

        base_match = float(match_core if match_core is not None else (match_all or 0.0))
        doc_score = self._clamp01(0.5 * float(extraction_score) + 0.5 * base_match)

        # Face is ON/OFF only:
        final_score = doc_score
        if doc_cfg.require_face_match:
            if face_val is not None:
                final_score = self._clamp01(0.65 * doc_score + 0.35 * self._clamp01(float(face_val)))
            else:
                final_score = doc_score

        decision, reasons = decide(
            doc_cfg=doc_cfg,
            coverage=coverage,
            extraction=extraction_score,
            match_all=match_all,
            match_core=match_core,
            count_gate_ok=count_gate_ok,
            detected_count=len(det_ok_keys),
            mismatch_flags=mismatch_flags,
            invalid_flags=invalid_flags,
            face_value=face_val,
            face_reason=face_reason,
        )

        return {
            "decision": decision,
            "reasons": reasons,
            "scores": {
                "coverage": coverage,
                "extraction": extraction_score,
                "match_all": match_all,
                "match_core": match_core,
                "detected_fields_count": len(det_ok_keys),
                "count_gate_ok": count_gate_ok,
                "face": face_val,
                "face_metric": doc_cfg.face_metric,
                "doc_score": doc_score,
                "final_score": final_score,
            },
            "per_field": per_field,
        }


# ============================================================
# 8) Decision logic (no optional selfie rescue)
# ============================================================

def decide(
    doc_cfg: DocConfig,
    coverage: float,
    extraction: float,
    match_all: Optional[float],
    match_core: Optional[float],
    count_gate_ok: bool,
    detected_count: int,
    mismatch_flags: List[str],
    invalid_flags: List[str],
    face_value: Optional[float],
    face_reason: Optional[str],
) -> Tuple[str, List[str]]:
    """Turn aggregate KYC scores and flags into a decision.

    Returns ("APPROVE" | "REVIEW" | "REJECT", reasons) where reasons is a
    list of short machine-readable tags explaining the outcome.

    Gate order: detected-field-count gate -> hard coverage floor ->
    invalid/mismatch flags -> face policy (on/off) -> approval thresholds ->
    review fallback.
    """
    reasons: List[str] = []

    # Hard reject: too few fields detected (only when the gate is configured).
    if doc_cfg.min_detected_fields_count is not None and not count_gate_ok:
        reasons.append("detected_fields_count_too_low")
        reasons.append(f"detected={detected_count}_min={doc_cfg.min_detected_fields_count}")
        return "REJECT", reasons

    # Hard reject: coverage below the absolute floor.
    if coverage < doc_cfg.reject_below_coverage:
        reasons.append("coverage_too_low")
        return "REJECT", reasons

    # Any validation failure or user-input mismatch blocks auto-approval
    # (invalid flags are reported first, matching their check order).
    for flags in (invalid_flags, mismatch_flags):
        if flags:
            reasons.extend(flags[:12])
            return "REVIEW", reasons

    # Face policy is strictly ON/OFF: when required it must exist and pass.
    if doc_cfg.require_face_match:
        if face_value is None:
            reasons.append(f"face_required_but_unavailable:{face_reason or 'unknown'}")
            return "REVIEW", reasons
        if float(face_value) < float(doc_cfg.face_match_threshold):
            reasons.append("face_below_threshold")
            reasons.append(f"face={face_value:.3f}_thr={doc_cfg.face_match_threshold:.3f}")
            return "REVIEW", reasons

    meets_approve_floor = (
        coverage >= doc_cfg.approve_min_coverage
        and extraction >= doc_cfg.approve_min_extraction
    )
    if meets_approve_floor:
        if match_core is None:
            # No core user input to match against: approve only when
            # extraction clears an extra safety buffer.
            buffered_floor = doc_cfg.approve_min_extraction + doc_cfg.approve_no_input_extra_buffer
            if extraction >= buffered_floor:
                return "APPROVE", reasons
            reasons.append("no_core_user_input")
            return "REVIEW", reasons

        if match_core >= doc_cfg.approve_min_match_core:
            return "APPROVE", reasons

        reasons.append("core_match_below_threshold")
        if match_all is not None:
            reasons.append(f"match_all={match_all:.3f}")
        return "REVIEW", reasons

    # Not approvable: explain which thresholds fell short, then route to
    # human review (coverage above the hard floor never auto-rejects here).
    if coverage >= doc_cfg.review_min_coverage:
        if extraction < doc_cfg.approve_min_extraction:
            reasons.append("extraction_low")
        if coverage < doc_cfg.approve_min_coverage:
            reasons.append("partial_coverage")
        return "REVIEW", reasons

    reasons.append("coverage_insufficient")
    return "REVIEW", reasons


# ============================================================
# 9) Swap utilities (same as before)
# ============================================================

def swap_field_keys(fields: Dict[str, Any], internals: Dict[str, Any], a: str, b: str, reason: str) -> None:
    """Swap two field entries (and their quads) in place, recording the event.

    Each swap happens only when BOTH keys are present in the respective dict;
    the swap event is appended to internals["swap_events"] unconditionally.
    """
    for bucket in (fields, internals.get("quads", {})):
        if a in bucket and b in bucket:
            bucket[a], bucket[b] = bucket[b], bucket[a]
    internals.setdefault("swap_events", []).append({"a": a, "b": b, "reason": reason})


def apply_config_swaps(fields: Dict[str, Any], internals: Dict[str, Any], cfg: DocConfig) -> None:
    """Apply every (a, b) swap pair declared in the document config."""
    for pair in cfg.swap_pairs:
        swap_field_keys(fields, internals, pair[0], pair[1], reason="config_swap_pairs")


def best_name_assignment(
    ocr_first_raw: str,
    ocr_last_raw: str,
    exp_first: str,
    exp_last: str,
    margin: float = 0.06
) -> Dict[str, Any]:
    """Decide whether OCR'd first/last names match the expected names better swapped.

    Compares the straight assignment (first->first, last->last) against the
    crossed one (first->last, last->first) using token-set similarity. A swap
    is suggested only when the crossed score beats the straight score by more
    than `margin`, so near-ties keep the original assignment.
    """
    straight_ff = token_set_similarity(ocr_first_raw or "", exp_first or "")
    straight_ll = token_set_similarity(ocr_last_raw or "", exp_last or "")
    crossed_fl = token_set_similarity(ocr_first_raw or "", exp_last or "")
    crossed_lf = token_set_similarity(ocr_last_raw or "", exp_first or "")

    straight_score = (straight_ff + straight_ll) / 2.0
    crossed_score = (crossed_fl + crossed_lf) / 2.0

    if crossed_score > straight_score + margin:
        return {
            "swap": True,
            "score_normal": straight_score,
            "score_swapped": crossed_score,
            "details": {"first->last": crossed_fl, "last->first": crossed_lf},
        }
    return {
        "swap": False,
        "score_normal": straight_score,
        "score_swapped": crossed_score,
        "details": {"first->first": straight_ff, "last->last": straight_ll},
    }


def apply_name_swap_if_needed(fields: Dict[str, Any], internals: Dict[str, Any], cfg: DocConfig, user_input: Dict[str, str]) -> None:
    """Swap the first/last name fields when OCR output matches the user input better crossed.

    No-op unless the heuristic is enabled in the config, both name fields were
    detected, and the user supplied both expected names. When a swap fires, the
    swap event recorded by swap_field_keys is enriched with the similarity
    breakdown from best_name_assignment.
    """
    if not cfg.enable_name_swap:
        return
    if "first_name" not in fields or "last_name" not in fields:
        return
    exp_first = user_input.get("first_name") or ""
    exp_last = user_input.get("last_name") or ""
    if not exp_first or not exp_last:
        return

    f_raw = fields["first_name"].get("value_raw") or ""
    l_raw = fields["last_name"].get("value_raw") or ""

    info = best_name_assignment(f_raw, l_raw, exp_first, exp_last, margin=cfg.name_swap_margin)
    if info.get("swap"):
        swap_field_keys(fields, internals, "first_name", "last_name", reason="heuristic_name_swap")
        # swap_field_keys already appended the swap event; attach the scoring
        # meta to that entry instead of appending a second, near-duplicate
        # event for the same swap (the original code double-logged it).
        internals["swap_events"][-1]["meta"] = info


# ============================================================
# 10) Public API
# ============================================================

def run_kyc(
    *,
    model_path: str,
    doc_image_path: str,
    doc_id: str,
    doc_config_payload: Optional[Dict[str, Any]] = None,
    doc_config: Optional[DocConfig] = None,

    # LLM OCR hook (REQUIRED)
    llm_ocr: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,

    user_input: Optional[Dict[str, str]] = None,
    selfie_image_path: Optional[str] = None,

    yolo_conf: float = 0.25,
    yolo_iou: float = 0.6,
    device: str = "0",
    max_det: int = 200,

    cache_engine: bool = True,
    debug: bool = False,
) -> Dict[str, Any]:
    """Run the full KYC pipeline for a single document image.

    Pipeline: YOLO detection -> field crops -> one LLM OCR call ->
    config/heuristic field swaps -> optional face match -> rule-based scoring.
    Returns a dict with decision, scores, reasons, per_field, selfie_mode and
    next_steps; debug=True additionally exposes raw detections, fields,
    internals, the face pack and the LLM request/response bundles.
    """
    # Resolve the document config: an inline payload wins over a prebuilt object.
    if doc_config_payload is not None:
        cfg = doc_config_from_payload(doc_config_payload)
    elif doc_config is not None:
        cfg = doc_config
    else:
        raise ValueError("doc_config_payload or doc_config must be provided (server-driven).")

    if llm_ocr is None:
        raise ValueError("llm_ocr callback is required in LLM-OCR edition.")

    # Either reuse a cached engine for this model path or build a fresh one.
    engine = KYCEngine(model_path) if not cache_engine else get_engine(model_path)

    detections = engine.detect(doc_image_path, conf=yolo_conf, iou=yolo_iou, device=device, max_det=max_det)

    extraction = engine.extract_with_llm(
        doc_image_path,
        cfg,
        detections,
        llm_ocr=llm_ocr,
        user_input=user_input,
        debug=debug
    )
    fields = extraction["fields"]
    internals = extraction["internals"]

    apply_config_swaps(fields, internals, cfg)
    apply_name_swap_if_needed(fields, internals, cfg, user_input or {})

    # ---------------- Face (ON/OFF) ----------------
    face_pack = None
    if cfg.require_face_match:
        if not selfie_image_path:
            face_pack = {"score01": None, "cosine": None, "reason": "selfie_missing", "details": None}
        else:
            quads = internals["quads"]
            photo_quad = quads.get("doc_photo") or quads.get("photo")
            face_pack = compute_face_pack(
                doc_image_path=doc_image_path,
                selfie_image_path=selfie_image_path,
                doc_photo_quad=photo_quad,
            )

    scoring = engine.score(cfg, fields, user_input=user_input, face_pack=face_pack, debug=debug)

    result: Dict[str, Any] = {
        "doc_id": doc_id,
        "decision": scoring["decision"],
        "scores": scoring["scores"],
        "reasons": scoring["reasons"],
        "per_field": scoring["per_field"],
    }

    # Client guidance: what (if anything) the user should do next.
    next_steps: List[str] = []
    if cfg.require_face_match and not selfie_image_path:
        next_steps.append("upload_selfie_required")
    if result["decision"] == "REJECT":
        next_steps.append("retake_document_photo")

    result["selfie_mode"] = "required" if cfg.require_face_match else "disabled"
    result["next_steps"] = next_steps

    if debug:
        result["detections"] = detections
        result["fields"] = fields
        result["internals"] = internals
        result["face"] = face_pack
        result["llm_bundle"] = extraction.get("llm_bundle")
        result["llm_response_raw"] = extraction.get("llm_response_raw")

    return result
