
# kyc_policy_engine.py (LLM OCR edition)
# ------------------------------------------------------------
# Trendo KYC: YOLO(OBB) -> crop fields -> ONE LLM request for OCR -> rule-based validation + (on/off) face match
#
# IMPORTANT:
# - This file intentionally REMOVES EasyOCR and all "manual OCR" logic.
# - Implement `llm_ocr(request_bundle) -> response_bundle` on your server and inject into run_kyc/run_kyc_auto_variant.
#   If `llm_ocr` is not provided, built-in Eden AI OCR callback is used.
#
from __future__ import annotations

import os
import re
import math
import json
import base64
import logging
import threading
from dataclasses import dataclass, field
from typing import Dict, Any, List, Optional, Tuple, Iterable, Callable

import cv2
import numpy as np
import torch
from ultralytics import YOLO

# ============================================================
# 1) Light normalization (ONLY for server-side matching/validation)
#    - OCR is handled by LLM, but we keep robust string normalization
#      so matching does not break on minor Arabic/Persian variants.
# ============================================================

# Persian (U+06F0-U+06F9) and Arabic-Indic (U+0660-U+0669) digits -> ASCII digits.
_FA_AR_TO_EN_DIGITS = str.maketrans({
    "۰": "0", "۱": "1", "۲": "2", "۳": "3", "۴": "4",
    "۵": "5", "۶": "6", "۷": "7", "۸": "8", "۹": "9",
    "٠": "0", "١": "1", "٢": "2", "٣": "3", "٤": "4",
    "٥": "5", "٦": "6", "٧": "7", "٨": "8", "٩": "9",
})

# Arabic letter variants folded to canonical Persian forms so matching is
# stable across Arabic/Persian keyboards and OCR output variants.
_AR_FA_LETTERS = str.maketrans({
    "ي": "ی",  # Arabic yeh -> Persian yeh
    "ى": "ی",  # alef maksura -> Persian yeh
    "ك": "ک",  # Arabic kaf -> Persian keheh
    "ة": "ه",  # teh marbuta -> heh
    "ۀ": "ه",
    "ؤ": "و",
    "إ": "ا",
    "أ": "ا",
    "ٱ": "ا",
    "ء": "",   # standalone hamza dropped
    "ئ": "ی",
    "ﻻ": "لا",  # lam-alef ligature expanded
    "ـ": "",   # tatweel
})

# Arabic diacritics (harakat etc.) are stripped before matching.
_DIACRITICS_RE = re.compile(r"[\u0610-\u061A\u064B-\u065F\u0670\u06D6-\u06ED]")
# Zero-width non-joiner; normalize_fa_text() replaces it with a plain space.
_ZWNJ = "\u200c"

_OCR_MAX_DIM = int(os.getenv("KYC_OCR_MAX_DIM", "2200"))
_OCR_JPEG_QUALITY = int(os.getenv("KYC_OCR_JPEG_QUALITY", "98"))
_EDENAI_URL = os.getenv("KYC_EDENAI_URL", "https://api.edenai.run/v3/universal-ai/")
_EDENAI_MODEL = "ocr/ocr/google"
_EDENAI_API_KEY = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoiNTc1MDVlYTgtOGI3Yy00ZDQyLTlhYmMtOGMxYzliOGEyZDRjIiwidHlwZSI6ImFwaV90b2tlbiJ9.PLCBmy1gNDTpqfzCJwfQlaIVY4EkcE0TL7lsXiOHTMY"
_EDENAI_TIMEOUT_SEC = float(os.getenv("KYC_EDENAI_TIMEOUT_SEC", "40"))
_EDENAI_SHOW_ORIGINAL_RESPONSE = os.getenv("KYC_EDENAI_SHOW_ORIGINAL_RESPONSE", "0").strip().lower() in (
    "1", "true", "yes", "on"
)
_LOG = logging.getLogger(__name__)


def collapse_spaces(s: str) -> str:
    """Trim *s* and squeeze every whitespace run down to a single space."""
    stripped = (s or "").strip()
    return re.sub(r"\s+", " ", stripped)


def to_en_digits(s: str) -> str:
    """Map Persian/Arabic-Indic digits in *s* to their ASCII equivalents."""
    text = s if s else ""
    return text.translate(_FA_AR_TO_EN_DIGITS)


def normalize_fa_text(s: str) -> str:
    """
    Light Persian text normalization for server-side matching:
    digit folding, Arabic->Persian letter folding, ZWNJ -> space,
    diacritic removal, whitespace collapsing.
    """
    text = to_en_digits(s or "")
    text = text.translate(_AR_FA_LETTERS)
    text = _DIACRITICS_RE.sub("", text.replace(_ZWNJ, " "))
    return collapse_spaces(text)


def normalize_name(s: str) -> str:
    """Normalize a personal name: fold script variants, keep word chars /
    Arabic-block letters / spaces / hyphens, then lowercase."""
    cleaned = re.sub(
        r"[^\w\u0600-\u06FF\u0750-\u077F\u08A0-\u08FF\s\-]",
        "",
        normalize_fa_text(s),
        flags=re.UNICODE,
    )
    return collapse_spaces(cleaned).lower()


def normalize_digits_only(s: str) -> str:
    """Keep only ASCII digits (after Persian/Arabic digit folding)."""
    return re.sub(r"[^0-9]", "", to_en_digits(s))


def normalize_id_alphanum(s: str) -> str:
    """Keep only digits and Latin letters (after digit folding), uppercased."""
    return re.sub(r"[^0-9A-Za-z]", "", to_en_digits(s)).upper()


# ============================================================
# 2) Similarity + loose date parsing (for matching)
# ============================================================

def levenshtein_ratio(a: str, b: str) -> float:
    """Similarity in [0, 1] derived from the Levenshtein edit distance
    (1.0 for equal strings, 0.0 when either side is empty)."""
    a, b = a or "", b or ""
    if a == b:
        return 1.0
    if not a or not b:
        return 0.0

    # Rolling single-row dynamic program over the edit-distance matrix.
    row = list(range(len(b) + 1))
    for i, ch_a in enumerate(a, 1):
        new_row = [i]
        for j, ch_b in enumerate(b, 1):
            substitution = row[j - 1] + (ch_a != ch_b)
            new_row.append(min(row[j] + 1, new_row[j - 1] + 1, substitution))
        row = new_row

    return 1.0 - row[-1] / max(len(a), len(b))


def token_set_similarity(a: str, b: str) -> float:
    """Blend of token Jaccard (60%) and Levenshtein ratio (40%) computed on
    normalized names; 0.0 when either normalizes to empty."""
    na, nb = normalize_name(a), normalize_name(b)
    if not na or not nb:
        return 0.0
    tokens_a, tokens_b = set(na.split()), set(nb.split())
    union = len(tokens_a | tokens_b)
    jaccard = (len(tokens_a & tokens_b) / union) if union else 0.0
    return 0.6 * jaccard + 0.4 * levenshtein_ratio(na, nb)


def parse_date_loose(s: str) -> Optional[Tuple[int, int, int]]:
    """
    Try to read a date out of *s* regardless of separator or component order.

    Returns (year, month, day) for the first component permutation whose
    values look plausible, or None. Used only for validation/matching,
    never for OCR itself.
    """
    text = to_en_digits(s or "").replace("-", "/").replace(".", "/")
    pieces = re.findall(r"\d{1,4}", text)
    if len(pieces) < 3:
        return None

    v0, v1, v2 = (int(p) for p in pieces[:3])
    for y, m, d in (
        (v0, v1, v2), (v0, v2, v1),
        (v1, v0, v2), (v1, v2, v0),
        (v2, v0, v1), (v2, v1, v0),
    ):
        if 1000 <= y <= 2200 and 1 <= m <= 12 and 1 <= d <= 31:
            return (y, m, d)
    return None


# ============================================================
# 3) Iran National Code checksum (validator only)
# ============================================================

def iran_national_code_is_valid(code: str) -> bool:
    """Iranian national-ID checksum: 10 digits, weighted mod-11 on the first
    nine digits (weights 10..2); all-identical digits are rejected."""
    code = normalize_digits_only(code)
    if len(code) != 10 or len(set(code)) == 1:
        return False

    digits = [int(ch) for ch in code]
    remainder = sum(d * w for d, w in zip(digits[:9], range(10, 1, -1))) % 11
    expected = remainder if remainder < 2 else 11 - remainder
    return digits[9] == expected

# ============================================================
# 3b) Generic validators registry (country-agnostic)
# ============================================================

def validate_id_nik(value_norm: str, constraints: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
    """
    Structural validation for Indonesia's NIK (KTP/e-KTP):
      * exactly 16 digits
      * digits 7-12 encode DOB as DDMMYY, with DD+40 marking female
      * no checksum exists, so only structure is checked
    """
    digits = normalize_digits_only(value_norm)
    if len(digits) != 16:
        return _v_fail("nik_length")

    # Stricter checks are opt-in via constraints to avoid false rejects.
    if constraints.get("reject_region_000000", False) and digits[:6] == "000000":
        return _v_fail("nik_region_all_zero")
    if constraints.get("reject_serial_0000", False) and digits[12:16] == "0000":
        return _v_fail("nik_serial_all_zero")

    # DOB segment is DDMMYY.
    dob = digits[6:12]
    try:
        day = int(dob[:2])
        month = int(dob[2:4])
        _year = int(dob[4:6])  # century unknown; not needed for structure
    except Exception:
        return _v_fail("nik_dob_parse")

    # Female records store day + 40.
    if 41 <= day <= 71:
        day -= 40
    elif not (1 <= day <= 31):
        return _v_fail("nik_day_range")

    if not (1 <= month <= 12):
        return _v_fail("nik_month_range")

    # Allow Feb 29 since the century (leap-ness) cannot be determined.
    days_in_month = {1: 31, 2: 29, 3: 31, 4: 30, 5: 31, 6: 30,
                     7: 31, 8: 31, 9: 30, 10: 31, 11: 30, 12: 31}
    if not (1 <= day <= days_in_month[month]):
        return _v_fail("nik_day_invalid")

    return _v_ok()

def default_locale_hint(country: str) -> str:
    """
    Conservative, deterministic OCR locale hint per country code.
    Override via policy when a better hint is known.
    """
    return "fa" if (country or "").strip().upper() == "IR" else "en"


def _get_edenai_api_key() -> str:
    """
    Resolve the Eden AI API key.

    Prefers the KYC_EDENAI_API_KEY environment variable so deployments can
    rotate credentials without a code change; falls back to the module-level
    constant for backward compatibility.

    Raises:
        ValueError: if no key is configured anywhere.
    """
    key = (os.getenv("KYC_EDENAI_API_KEY") or _EDENAI_API_KEY).strip()
    if not key:
        raise ValueError("Missing Eden AI API key (set KYC_EDENAI_API_KEY).")
    return key


def _edenai_language_from_locale(locale_hint: str) -> str:
    lh = (locale_hint or "").strip().lower()
    if not lh:
        return "en"
    base = lh.split("-", 1)[0].split("_", 1)[0]
    if not base:
        return "en"
    return base


def _deep_values_for_keys(data: Any, keys: Iterable[str]) -> List[Any]:
    keyset = {str(k).lower() for k in keys}
    out: List[Any] = []

    def walk(node: Any) -> None:
        if isinstance(node, dict):
            for k, v in node.items():
                if str(k).lower() in keyset:
                    out.append(v)
                walk(v)
        elif isinstance(node, list):
            for x in node:
                walk(x)

    walk(data)
    return out


def _as_nonempty_text(v: Any) -> Optional[str]:
    """Return the whitespace-collapsed string if *v* is a non-empty str, else None."""
    if not isinstance(v, str):
        return None
    text = collapse_spaces(v)
    return text if text else None


def _to_conf01(v: Any) -> Optional[float]:
    try:
        c = float(v)
    except Exception:
        return None
    if c > 1.0 and c <= 100.0:
        c = c / 100.0
    if c < 0.0:
        c = 0.0
    if c > 1.0:
        c = 1.0
    return c


def _extract_eden_text_conf(output: Any) -> Tuple[Optional[str], Optional[float]]:
    """Pull the longest text and the highest confidence found anywhere inside
    an Eden AI response payload (keys are matched case-insensitively)."""
    text_keys = ("text", "full_text", "raw_text", "ocr_text", "content", "transcription", "value")
    conf_keys = ("confidence", "score", "ocr_confidence", "global_score", "probability")

    longest: Optional[str] = None
    for raw in _deep_values_for_keys(output, text_keys):
        candidate = _as_nonempty_text(raw)
        if candidate and (longest is None or len(candidate) > len(longest)):
            longest = candidate

    scores: List[float] = []
    for raw in _deep_values_for_keys(output, conf_keys):
        score = _to_conf01(raw)
        if score is not None:
            scores.append(score)

    return longest, (max(scores) if scores else None)


def _call_edenai_ocr_single(*, image_data_url: str, language: str, api_key: str) -> Tuple[Optional[str], Optional[float]]:
    """
    POST one cropped image to Eden AI universal OCR and return (text, confidence).

    Raises RuntimeError when `requests` is unavailable, the response body is
    not JSON, or Eden AI reports a non-success status.
    """
    try:
        # Imported lazily so this module stays importable without requests.
        import requests
    except Exception as e:
        raise RuntimeError("requests package is required for Eden AI OCR.") from e

    resp = requests.post(
        _EDENAI_URL,
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        json={
            "model": _EDENAI_MODEL,
            "input": {"file": image_data_url, "language": language},
            "show_original_response": _EDENAI_SHOW_ORIGINAL_RESPONSE,
        },
        timeout=_EDENAI_TIMEOUT_SEC,
    )

    try:
        data = resp.json()
    except Exception as e:
        raise RuntimeError(f"Eden AI response is not JSON (http={resp.status_code}).") from e

    if data.get("status") != "success":
        detail = data.get("error") or data.get("message") or f"http_{resp.status_code}"
        raise RuntimeError(f"Eden AI OCR failed: {detail}")

    return _extract_eden_text_conf(data.get("output"))


def eden_llm_ocr(request_bundle: Dict[str, Any]) -> Dict[str, Any]:
    """
    Adapter: KYC request bundle -> Eden AI OCR.

    For each field in the bundle, OCRs every candidate crop and keeps the
    best one (scored as +0.5 for any text plus the reported confidence).
    Returns a response in the same shape expected by extract_with_llm().

    Individual candidate failures are logged and skipped (best-effort);
    a field with no usable candidate yields text=None / confidence=0.0.
    """
    api_key = _get_edenai_api_key()
    locale_hint = str(request_bundle.get("locale_hint") or "en")
    language = _edenai_language_from_locale(locale_hint)

    out_fields: Dict[str, Any] = {}

    # Loop variable renamed from `field` to avoid shadowing dataclasses.field.
    for spec in (request_bundle.get("fields") or []):
        if not isinstance(spec, dict):
            continue
        key = str(spec.get("key") or "").strip()
        if not key:
            continue

        best_idx = 0
        best_text: Optional[str] = None
        best_conf: Optional[float] = None
        best_score = -1.0

        for i, cand in enumerate(spec.get("candidates") or []):
            if not isinstance(cand, dict):
                continue

            image_data_url = cand.get("image_data_url")
            if not isinstance(image_data_url, str) or not image_data_url.strip():
                continue

            try:
                text, conf = _call_edenai_ocr_single(
                    image_data_url=image_data_url,
                    language=language,
                    api_key=api_key,
                )
            except Exception as e:
                _LOG.warning("Eden OCR failed for key=%s candidate=%s: %s", key, i, e)
                continue

            score = (0.5 if text else 0.0) + (conf if conf is not None else 0.0)

            if score > best_score:
                best_score = score
                # Prefer the candidate's declared idx; fall back to position.
                try:
                    best_idx = int(cand.get("idx"))
                except Exception:
                    best_idx = i
                best_text = text
                best_conf = conf

        out_fields[key] = {
            "candidate_idx": best_idx,
            "text": best_text,
            "confidence": (best_conf if best_conf is not None else 0.0),
            "normalized": None,
        }

    return {"fields": out_fields}


# Signature shared by all validators: (normalized_value, constraints) -> (ok, fail_reason).
ValidatorFn = Callable[[str, Dict[str, Any]], Tuple[bool, Optional[str]]]


def _parse_int_list(v: Any) -> Optional[List[int]]:
    if v is None:
        return None
    if isinstance(v, list):
        out: List[int] = []
        for x in v:
            try:
                out.append(int(x))
            except Exception:
                return None
        return out
    if isinstance(v, str):
        s = v.strip()
        if not s:
            return None
        # Accept "10,9,8" or "10 9 8" or "[10,9,8]"
        if s.startswith("[") and s.endswith("]"):
            try:
                arr = json.loads(s)
                if isinstance(arr, list):
                    return _parse_int_list(arr)
            except Exception:
                pass
        parts = re.split(r"[\s,;]+", s)
        out = []
        for p in parts:
            if not p:
                continue
            try:
                out.append(int(p))
            except Exception:
                return None
        return out
    return None


def luhn_is_valid(code: str) -> bool:
    """Standard Luhn mod-10 checksum; the last digit is the check digit."""
    code = normalize_digits_only(code)
    if len(code) < 2:
        return False

    digits = [int(ch) for ch in code]
    # Double every second digit counting from the right, excluding the check digit.
    parity = (len(digits) - 2) % 2
    total = 0
    for pos, d in enumerate(digits[:-1]):
        if pos % 2 == parity:
            d *= 2
            if d > 9:
                d -= 9
        total += d

    return (total + digits[-1]) % 10 == 0


def mod11_weighted_is_valid(
        code: str,
        *,
        weights: List[int],
        mod: int = 11,
        check_index: Optional[int] = None,
        check_rule: str = "mod_minus_r",
) -> bool:
    """
    Generic weighted mod-11 validator.

    Parameters are taken from constraints:
      - weights: list[int] for all digits EXCEPT the check digit.
      - mod: default 11
      - check_index: index of check digit (default last); may be negative
      - check_rule:
          * "iran"        : expected = r if r < 2 else (11 - r)
          * "mod_minus_r" : expected = (mod - r) % mod
          * "r"           : expected = r

    Returns False (fails closed) for empty input, too-few digits, an
    out-of-range check_index, a weight list of the wrong length, or an
    unknown check_rule.
    """
    code = normalize_digits_only(code)
    if not code:
        return False

    digits = [int(c) for c in code]
    n = len(digits)
    if n < 2:
        return False

    # Resolve the check-digit position; negative indices count from the end.
    ci = (n - 1) if check_index is None else int(check_index)
    if ci < 0:
        ci = n + ci
    if ci < 0 or ci >= n:
        return False

    check_digit = digits[ci]
    data_digits = [digits[i] for i in range(n) if i != ci]

    # The weight list must cover exactly the non-check digits.
    if len(weights) != len(data_digits):
        return False

    s = sum(d * w for d, w in zip(data_digits, weights))
    r = s % int(mod)

    rule = (check_rule or "").strip().lower()
    if rule == "iran":
        expected = r if r < 2 else (int(mod) - r)
    elif rule in {"mod_minus_r", "mod-11"}:
        expected = (int(mod) - r) % int(mod)
    elif rule == "r":
        expected = r
    else:
        # Unknown rule names fail closed rather than silently passing.
        return False

    return int(check_digit) == int(expected)


def _v_ok() -> Tuple[bool, Optional[str]]:
    return True, None


def _v_fail(reason: str) -> Tuple[bool, Optional[str]]:
    return False, reason


def validate_iran_national_code(value_norm: str, constraints: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
    """Registry wrapper around the Iran national-code checksum."""
    return _v_ok() if iran_national_code_is_valid(value_norm) else _v_fail("iran_national_code_invalid")


def validate_luhn(value_norm: str, constraints: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
    """Registry wrapper around the Luhn mod-10 checksum."""
    return _v_ok() if luhn_is_valid(value_norm) else _v_fail("luhn_invalid")


def validate_mod11_weighted(value_norm: str, constraints: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
    """Registry adapter for mod11_weighted_is_valid; all parameters come from constraints."""
    weights = _parse_int_list(constraints.get("weights") or constraints.get("weight_list"))
    if not weights:
        return _v_fail("mod11_missing_weights")

    check_index = constraints.get("check_index")
    try:
        ok = mod11_weighted_is_valid(
            value_norm,
            weights=weights,
            mod=int(constraints.get("mod", 11)),
            check_index=(None if check_index is None else int(check_index)),
            check_rule=str(constraints.get("check_rule", "mod_minus_r")),
        )
    except Exception:
        # Malformed constraints become a failed validation, never a crash.
        ok = False

    return _v_ok() if ok else _v_fail("mod11_invalid")


def validate_none(value_norm: str, constraints: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
    """Accept anything; used for the 'none'/'skip' validator names."""
    return True, None


def validate_missing_identity_rule(value_norm: str, constraints: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
    """Always fail: no identity rule was configured for this document."""
    return False, "identity_rule_missing"


# Name -> validator dispatch table. Keys are the lowercase names accepted by
# run_validator(); "skip" is an alias of "none".
_VALIDATOR_REGISTRY: Dict[str, ValidatorFn] = {
    "iran_national_code": validate_iran_national_code,
    "luhn": validate_luhn,
    "mod11_weighted": validate_mod11_weighted,
    "id_nik": validate_id_nik,
    "none": validate_none,
    "skip": validate_none,
    "missing_identity_rule": validate_missing_identity_rule,
}


def run_validator(validator_name: Optional[str], value_norm: str, constraints: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
    """
    Dispatch to a deterministic validator by (case-insensitive) name.

    An empty/None name passes; an unknown name fails closed — safer than
    silently skipping validation.
    """
    name = (validator_name or "").strip().lower()
    if not name:
        return _v_ok()

    handler = _VALIDATOR_REGISTRY.get(name)
    if handler is None:
        return _v_fail(f"unknown_validator:{name}")
    return handler(value_norm or "", constraints or {})



# ============================================================
# 4) OBB crop utilities
# ============================================================
def expand_quad(quad: np.ndarray, scale: float) -> np.ndarray:
    """Scale a 4-point quad about its centroid by *scale* (float32 result)."""
    pts = quad.astype(np.float32)
    center = pts.mean(axis=0, keepdims=True)
    return center + (pts - center) * float(scale)

def order_points_clockwise(pts: np.ndarray) -> np.ndarray:
    """Order 4 points as [top-left, top-right, bottom-right, bottom-left]
    using the classic sum/diff heuristic."""
    pts = pts.astype(np.float32)
    coord_sum = pts.sum(axis=1)
    coord_diff = np.diff(pts, axis=1).reshape(-1)  # y - x per point
    top_left = pts[np.argmin(coord_sum)]
    bottom_right = pts[np.argmax(coord_sum)]
    top_right = pts[np.argmin(coord_diff)]
    bottom_left = pts[np.argmax(coord_diff)]
    return np.array([top_left, top_right, bottom_right, bottom_left], dtype=np.float32)


def clip_points(pts: np.ndarray, w: int, h: int) -> np.ndarray:
    """Clamp x into [0, w-1] and y into [0, h-1], mutating *pts* in place;
    the same array is returned for chaining."""
    pts[:, 0] = np.clip(pts[:, 0], 0, w - 1)
    pts[:, 1] = np.clip(pts[:, 1], 0, h - 1)
    return pts


def warp_quad_to_rect(img_bgr: np.ndarray, quad: np.ndarray) -> np.ndarray:
    """Perspective-warp an arbitrary quad region of *img_bgr* into an
    axis-aligned rectangle sized by the quad's longer opposing edges."""
    img_h, img_w = img_bgr.shape[:2]
    src = order_points_clockwise(clip_points(quad.astype(np.float32), img_w, img_h))

    tl, tr, br, bl = src
    # Output size follows the longer of each opposing edge pair (minimum 2 px).
    out_w = max(2, int(max(np.linalg.norm(br - bl), np.linalg.norm(tr - tl))))
    out_h = max(2, int(max(np.linalg.norm(tr - br), np.linalg.norm(tl - bl))))

    dst = np.array(
        [[0, 0], [out_w - 1, 0], [out_w - 1, out_h - 1], [0, out_h - 1]],
        dtype=np.float32,
    )
    M = cv2.getPerspectiveTransform(src, dst)
    return cv2.warpPerspective(img_bgr, M, (out_w, out_h), flags=cv2.INTER_CUBIC)

def apply_numeric_ocr_filter_cv2(img_bgr: np.ndarray) -> np.ndarray:
    """
    Classic threshold preprocessing for numeric crops:
    grayscale, hard BINARY threshold at 120, then a white border
    (>= 12 px, ~6% of the longer side), returned as 3-channel BGR.
    """
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(gray, 120, 255, cv2.THRESH_BINARY)

    h, w = binary.shape[:2]
    border = max(12, int(round(0.06 * max(h, w))))
    binary = cv2.copyMakeBorder(binary, border, border, border, border, cv2.BORDER_CONSTANT, value=255)

    return cv2.cvtColor(binary, cv2.COLOR_GRAY2BGR)

def apply_smart_ocr_filter_cv2(img_bgr: np.ndarray) -> np.ndarray:
    """Grayscale + contrast stretch + sharpening + white padding, tuned for LLM OCR."""
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)

    # Contrast stretching; alpha in 1.3-1.5 works well for document crops.
    contrast = cv2.convertScaleAbs(gray, alpha=1.4, beta=10)

    # Sharpening is crucial for digits.
    sharpen_kernel = np.array([[0, -1, 0],
                               [-1, 5, -1],
                               [0, -1, 0]])
    sharp = cv2.filter2D(contrast, -1, sharpen_kernel)

    # White padding: downstream OCR performs better with a border.
    pad = 10  # pixels
    padded = cv2.copyMakeBorder(sharp, pad, pad, pad, pad, cv2.BORDER_CONSTANT, value=255)

    # Back to 3 channels for consistent downstream JPEG encoding.
    return cv2.cvtColor(padded, cv2.COLOR_GRAY2BGR)


def apply_adaptive_ocr_threshold(img_bgr: np.ndarray) -> np.ndarray:
    """Simple binary threshold for card OCR (requested approach).

    NOTE(review): despite the name this is a FIXED (non-adaptive) binary
    threshold at 120 and is currently byte-identical in behavior to
    apply_numeric_ocr_filter_cv2 — consider consolidating the two.
    """
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    _, thresh = cv2.threshold(gray, 120, 255, cv2.THRESH_BINARY)

    h, w = thresh.shape[:2]
    # White border (>= 12 px, ~6% of the longer side) before re-expanding to BGR.
    pad = max(12, int(round(0.06 * max(h, w))))
    thresh = cv2.copyMakeBorder(thresh, pad, pad, pad, pad, cv2.BORDER_CONSTANT, value=255)
    return cv2.cvtColor(thresh, cv2.COLOR_GRAY2BGR)


def apply_sharp_ocr_filter_cv2(img: np.ndarray) -> np.ndarray:
    """Apply a 3x3 unsharp-style sharpening kernel to *img*."""
    sharpen_kernel = np.array([[0, -1, 0],
                               [-1, 5, -1],
                               [0, -1, 0]])
    return cv2.filter2D(img, -1, sharpen_kernel)

def _resize_max_dim(img: np.ndarray, max_dim: int = 1000) -> np.ndarray:
    """Downscale so the longest side is at most *max_dim*; never upscales."""
    h, w = img.shape[:2]
    longest = max(h, w)
    if longest <= max_dim:
        return img
    factor = max_dim / float(longest)
    new_size = (max(2, int(round(w * factor))), max(2, int(round(h * factor))))
    return cv2.resize(img, new_size, interpolation=cv2.INTER_AREA)

def _resize_to_range(img: np.ndarray, *, max_dim: int = 1400, min_h: int = 0, min_w: int = 0) -> np.ndarray:
    """Downscale so the long side fits *max_dim*, then upscale (cubic) until
    both min_h/min_w are satisfied — small crops are OCR-critical."""
    h, w = img.shape[:2]

    # Downscale if too big.
    longest = max(h, w)
    if longest > max_dim:
        shrink = max_dim / float(longest)
        img = cv2.resize(
            img,
            (max(2, int(round(w * shrink))), max(2, int(round(h * shrink)))),
            interpolation=cv2.INTER_AREA,
        )
        h, w = img.shape[:2]

    # Upscale if below the requested minimums.
    grow = 1.0
    if min_h and h < min_h:
        grow = max(grow, min_h / float(h))
    if min_w and w < min_w:
        grow = max(grow, min_w / float(w))

    if grow > 1.01:
        img = cv2.resize(img, (int(round(w * grow)), int(round(h * grow))), interpolation=cv2.INTER_CUBIC)

    return img

def _normalize_jpeg_quality(value: int, default: int = 98) -> int:
    try:
        q = int(value)
    except Exception:
        return default
    return max(40, min(100, q))

def _encode_image_for_llm(
        img_bgr: np.ndarray,
        *,
        kind: str = "text",
        variant: str = "smart",
        max_dim: int = _OCR_MAX_DIM,
        jpeg_quality: int = _OCR_JPEG_QUALITY,
) -> str:
    """
    Resize + preprocess a field crop and encode it as a JPEG data URL for LLM OCR.

    kind: "numeric"/"date" crops get stronger minimum-size guarantees.
    variant: "raw" (no filter), "sharpen", "thresh" (kind-dependent), anything
    else falls back to the default "smart" filter.

    Raises RuntimeError if JPEG encoding fails.
    """
    kind = (kind or "text").lower()
    variant = (variant or "smart").lower()

    # Numeric/date fields need more pixels for reliable digit OCR.
    if kind in ("numeric", "date"):
        img_bgr = _resize_to_range(img_bgr, max_dim=max_dim, min_h=180, min_w=600)
    else:
        img_bgr = _resize_to_range(img_bgr, max_dim=max_dim, min_h=120, min_w=0)

    if variant == "raw":
        # FIX: "raw" previously ran the sharpening filter, making it
        # indistinguishable from "sharpen"; a raw variant must pass the
        # image through untouched so the auto-variant sweep has a baseline.
        processed = img_bgr
    elif variant == "sharpen":
        processed = apply_sharp_ocr_filter_cv2(img_bgr)
    elif variant == "thresh":
        if kind in ("numeric", "date"):
            processed = apply_numeric_ocr_filter_cv2(img_bgr)
        else:
            processed = apply_adaptive_ocr_threshold(img_bgr)
    else:
        processed = apply_smart_ocr_filter_cv2(img_bgr)  # default "smart" filter

    q = _normalize_jpeg_quality(jpeg_quality, _OCR_JPEG_QUALITY)
    ok, buf = cv2.imencode(".jpg", processed, [cv2.IMWRITE_JPEG_QUALITY, q])
    if not ok:
        raise RuntimeError("cv2.imencode(.jpg) failed")
    b64 = base64.b64encode(buf.tobytes()).decode("ascii")
    return f"data:image/jpeg;base64,{b64}"

# ============================================================
# 5) Face matcher (ON/OFF) - InsightFace
# ============================================================

class FaceMatcherBase:
    """Interface for face matchers. Implementations return a dict with keys
    score01 / cosine / reason / details."""

    def match(self, doc_photo_bgr: np.ndarray, selfie_bgr: np.ndarray) -> Dict[str, Any]:
        """Compare the document portrait crop against the selfie image."""
        raise NotImplementedError()


class FaceMatcherUnavailable(FaceMatcherBase):
    """Fallback matcher used when InsightFace cannot be constructed
    (missing dependency/models); always reports 'face_module_unavailable'."""

    def match(self, doc_photo_bgr: np.ndarray, selfie_bgr: np.ndarray) -> Dict[str, Any]:
        # Never raises, so the KYC flow degrades gracefully without face matching.
        return {"score01": None, "cosine": None, "reason": "face_module_unavailable", "details": None}


class InsightFaceMatcher(FaceMatcherBase):
    """Face matcher backed by InsightFace's 'buffalo_l' pack (detection + embeddings)."""

    def __init__(self, det_size: Tuple[int, int] = (640, 640), prefer_gpu: bool = False):
        # Imported here so the module stays importable when insightface is
        # absent; get_face_matcher() catches the failure and falls back.
        from insightface.app import FaceAnalysis  # type: ignore
        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if prefer_gpu else ["CPUExecutionProvider"]
        self.app = FaceAnalysis(name="buffalo_l",root="/var/www/html/runtime", providers=providers)
        # NOTE(review): ctx_id=-1 forces the CPU context even when
        # prefer_gpu=True — confirm whether that is intended.
        self.app.prepare(ctx_id=-1, det_size=det_size)

    @staticmethod
    def _largest_face(faces):
        # Pick the detection with the largest bounding-box area, or None.
        if not faces:
            return None

        def area(f):
            x1, y1, x2, y2 = f.bbox
            return float((x2 - x1) * (y2 - y1))

        return max(faces, key=area)

    @staticmethod
    def _cosine(a: np.ndarray, b: np.ndarray) -> float:
        # Cosine similarity; epsilons guard against zero-norm embeddings.
        a = a.astype(np.float32)
        b = b.astype(np.float32)
        na = np.linalg.norm(a) + 1e-9
        nb = np.linalg.norm(b) + 1e-9
        return float(np.dot(a, b) / (na * nb))

    def match(self, doc_photo_bgr: np.ndarray, selfie_bgr: np.ndarray) -> Dict[str, Any]:
        """Compare the largest detected face in each image.

        score01 maps cosine similarity from [-1, 1] into [0, 1]; a missing
        face on either side yields score01=None with an explanatory reason.
        """
        doc_faces = self.app.get(doc_photo_bgr)
        sel_faces = self.app.get(selfie_bgr)

        fd = self._largest_face(doc_faces)
        fs = self._largest_face(sel_faces)

        if fd is None:
            return {"score01": None, "cosine": None, "reason": "no_face_in_document_photo", "details": {"doc_faces": len(doc_faces)}}
        if fs is None:
            return {"score01": None, "cosine": None, "reason": "no_face_in_selfie", "details": {"selfie_faces": len(sel_faces)}}

        cos = self._cosine(fd.normed_embedding, fs.normed_embedding)
        score01 = (cos + 1.0) / 2.0
        return {"score01": float(score01), "cosine": float(cos), "reason": "ok", "details": None}


# Process-wide singleton face matcher, built lazily under _FACE_LOCK.
_FACE_MATCHER: Optional[FaceMatcherBase] = None
_FACE_LOCK = threading.Lock()


def get_face_matcher() -> FaceMatcherBase:
    """Lazily build (once, thread-safely) and return the process-wide matcher."""
    global _FACE_MATCHER
    with _FACE_LOCK:
        if _FACE_MATCHER is None:
            try:
                _FACE_MATCHER = InsightFaceMatcher(prefer_gpu=False)
            except Exception:
                # Deliberate best-effort: degrade to the stub rather than fail KYC.
                _FACE_MATCHER = FaceMatcherUnavailable()
        return _FACE_MATCHER


def compute_face_pack(
        *,
        doc_image_path: str,
        selfie_image_path: str,
        doc_photo_quad: Optional[List[List[float]]],
) -> Dict[str, Any]:
    """
    Compute the face-match pack once; the caller decides from policy whether
    the result is actually used. Missing inputs yield a null pack with a reason.
    """
    def _null_pack(reason: str) -> Dict[str, Any]:
        return {"score01": None, "cosine": None, "reason": reason, "details": None}

    if not selfie_image_path:
        return _null_pack("selfie_missing")
    if not doc_photo_quad:
        return _null_pack("doc_photo_not_detected")

    doc_img = cv2.imread(doc_image_path)
    sel_img = cv2.imread(selfie_image_path)
    if doc_img is None:
        return _null_pack("cannot_read_doc_image")
    if sel_img is None:
        return _null_pack("cannot_read_selfie_image")

    portrait = warp_quad_to_rect(doc_img, np.array(doc_photo_quad, dtype=np.float32))
    return get_face_matcher().match(portrait, sel_img)


# ============================================================
# 6) Config schema (server-driven)
# ============================================================

@dataclass
class FieldRule:
    """Server-driven rule describing one document field: which detector
    classes supply it, how to OCR it, and how to validate/match it."""

    class_names: List[str]  # YOLO class name(s) that can supply this field
    key: str  # canonical field key (also the user_input lookup key)
    required: bool

    min_det_conf: float = 0.25  # minimum detector confidence to accept a box

    # NOTE: In LLM-OCR edition, this is the minimum confidence you expect
    # the LLM to report for the extracted value (if you use it).
    min_ocr_conf: float = 0.20

    ocr_kind: str = "text"  # none|text|numeric|date|mrz

    match_type: str = "optional"  # optional|exact|fuzzy|date
    match_threshold: float = 0.85  # minimum similarity for fuzzy matching

    expected_len: int = 0  # 0 = no length constraint

    validator: Optional[str] = None  # name resolved via _VALIDATOR_REGISTRY
    constraints: Dict[str, Any] = field(default_factory=dict)  # validator parameters

    weight: float = 1.0
    must_match: bool = False
    match_gate: bool = False

    max_candidates: int = 1  # how many crops to offer the OCR per field

    input_aliases: Optional[List[str]] = None  # alternative user_input keys


@dataclass
class DocConfig:
    """Full server-driven policy for one document type: field rules plus
    coverage/extraction/match thresholds and the face-match switch."""

    doc_id: str
    rules: List[FieldRule]

    # metadata (useful for country-aware validators and OCR locale hints)
    country: str = ""
    doc_type: str = ""
    ocr_locale_hint: str = "en"

    min_detected_fields_count: Optional[int] = None

    # Approval thresholds (fractions in [0, 1]).
    approve_min_coverage: float = 1.0
    approve_min_extraction: float = 0.78
    approve_min_match_core: float = 0.92
    approve_min_match_all: float = 0.0  # analytics only

    review_min_coverage: float = 0.75
    reject_below_coverage: float = 0.50

    approve_no_input_extra_buffer: float = 0.10

    # --- Face policy: ON/OFF (no optional) ---
    require_face_match: bool = False
    face_metric: str = "score01"   # score01|cosine
    face_match_threshold: float = 0.75

    # name swap is still useful for layout variants
    swap_pairs: List[Tuple[str, str]] = field(default_factory=list)
    enable_name_swap: bool = True
    name_swap_margin: float = 0.06


def doc_config_from_payload(payload: Dict[str, Any]) -> DocConfig:
    """Build a DocConfig from a server-provided dict.

    Tolerates legacy shapes: a rule's "class_name" (str) is lifted into
    "class_names" (list), and "constraints" may arrive as a JSON string;
    malformed constraints degrade to an empty dict rather than raising.
    """
    rules: List[FieldRule] = []
    for r in payload.get("rules", []):
        rr = dict(r)
        # Accept either "class_names" (list) or legacy "class_name" (str).
        cn = rr.get("class_names") or rr.get("class_name")
        if isinstance(cn, str):
            rr["class_names"] = [cn]
        else:
            rr["class_names"] = list(cn or [])
        rr.pop("class_name", None)

        # Constraints may be missing, a JSON string, or already a dict.
        cst = rr.get("constraints")
        if cst is None:
            rr["constraints"] = {}
        elif isinstance(cst, str):
            try:
                rr["constraints"] = json.loads(cst)
            except Exception:
                rr["constraints"] = {}
        elif not isinstance(cst, dict):
            rr["constraints"] = {}

        rules.append(FieldRule(**rr))

    # Keep only well-formed 2-item swap pairs.
    swap_pairs = payload.get("swap_pairs") or []
    swap_pairs = [tuple(x) for x in swap_pairs if isinstance(x, (list, tuple)) and len(x) == 2]

    return DocConfig(
        doc_id=str(payload.get("doc_id") or payload.get("id") or "doc"),
        rules=rules,

        country=str(payload.get("country") or "").upper(),
        doc_type=str(payload.get("doc_type") or ""),
        # Locale hint precedence: translate > ocr_locale_hint > locale_hint > country default.
        ocr_locale_hint=str(
            payload.get("translate")
            or payload.get("ocr_locale_hint")
            or payload.get("locale_hint")
            or default_locale_hint(str(payload.get("country") or ""))
        ),

        min_detected_fields_count=payload.get("min_detected_fields_count", None),

        approve_min_coverage=float(payload.get("approve_min_coverage", 1.0)),
        approve_min_extraction=float(payload.get("approve_min_extraction", 0.78)),
        # Falls back to the legacy "approve_min_match" key.
        approve_min_match_core=float(payload.get("approve_min_match_core", payload.get("approve_min_match", 0.92))),
        approve_min_match_all=float(payload.get("approve_min_match_all", 0.0)),

        review_min_coverage=float(payload.get("review_min_coverage", 0.75)),
        reject_below_coverage=float(payload.get("reject_below_coverage", 0.50)),
        approve_no_input_extra_buffer=float(payload.get("approve_no_input_extra_buffer", 0.10)),

        require_face_match=bool(payload.get("require_face_match", False)),
        face_metric=str(payload.get("face_metric", "score01")),
        face_match_threshold=float(payload.get("face_match_threshold", 0.75)),

        swap_pairs=swap_pairs,
        enable_name_swap=bool(payload.get("enable_name_swap", True)),
        name_swap_margin=float(payload.get("name_swap_margin", 0.06)),
    )


# ============================================================
# 7) Engine caching (YOLO only)
# ============================================================

# Process-wide cache of loaded YOLO engines keyed by model path; guarded by _ENGINE_LOCK.
_ENGINE_CACHE: Dict[str, "KYCEngine"] = {}
_ENGINE_LOCK = threading.Lock()


def get_engine(model_path: str) -> "KYCEngine":
    """Return the process-wide cached KYCEngine for *model_path*, creating it on first use.

    Thread-safe: creation happens under _ENGINE_LOCK so concurrent callers
    never load the same YOLO model twice.
    """
    cache_key = str(model_path)
    with _ENGINE_LOCK:
        engine = _ENGINE_CACHE.get(cache_key)
        if engine is None:
            engine = KYCEngine(model_path)
            _ENGINE_CACHE[cache_key] = engine
        return engine


class KYCEngine:
    def __init__(self, yolo_model_path: str):
        """Load the YOLO(OBB) model and build an id -> class-name lookup table."""
        self.model = YOLO(yolo_model_path)

        # Ultralytics exposes `names` either as {id: name} or as a plain sequence;
        # normalize both shapes to {int: str}.
        raw_names = self.model.names
        if isinstance(raw_names, dict):
            self.class_names = {int(cid): str(label) for cid, label in raw_names.items()}
        else:
            self.class_names = {pos: str(label) for pos, label in enumerate(raw_names)}

        # predict() calls are serialized through this lock (see detect()).
        self._infer_lock = threading.Lock()

    @staticmethod
    def _clamp01(x: float) -> float:
        return max(0.0, min(1.0, float(x)))

    def _get_expected(self, rule: FieldRule, user_input: Dict[str, str]) -> Optional[str]:
        v = user_input.get(rule.key)
        if v:
            return v
        if rule.input_aliases:
            for ak in rule.input_aliases:
                vv = user_input.get(ak)
                if vv:
                    return vv
        return None

    def detect(self, image_path: str, conf: float, iou: float, device: str, max_det: int) -> List[Dict[str, Any]]:
        """
        Run YOLO inference on one image and return detections as plain dicts.

        Each detection has the shape:
          {"class_id": int, "class_name": str, "conf": float, "quad": [[x, y] * 4]}

        Oriented boxes (OBB) are used when the model produces them; otherwise
        axis-aligned boxes are converted to 4-point quads so downstream
        cropping code deals with a single geometry format.
        """
        # Serialize predict() calls: the shared model instance is not assumed
        # to be thread-safe across concurrent requests.
        with self._infer_lock:
            preds = self.model.predict(
                source=image_path, conf=conf, iou=iou, device=device, max_det=max_det, verbose=False
            )

        if not preds:
            return []

        # Single-image input -> only the first Results object matters.
        r = preds[0]
        dets: List[Dict[str, Any]] = []

        # OBB path
        if getattr(r, "obb", None) is not None and r.obb is not None:
            quads = r.obb.xyxyxyxy
            cls = r.obb.cls
            cf = r.obb.conf

            # Tensors may live on GPU; detach + move to CPU numpy before indexing.
            if isinstance(quads, torch.Tensor):
                quads = quads.detach().cpu().numpy()
            if isinstance(cls, torch.Tensor):
                cls = cls.detach().cpu().numpy()
            if isinstance(cf, torch.Tensor):
                cf = cf.detach().cpu().numpy()

            for i in range(len(quads)):
                cid = int(cls[i])
                dets.append({
                    "class_id": cid,
                    "class_name": self.class_names.get(cid, str(cid)),
                    "conf": float(cf[i]),
                    "quad": quads[i].astype(float).tolist(),
                })
            return dets

        # XYXY fallback
        if getattr(r, "boxes", None) is not None and r.boxes is not None:
            xyxy = r.boxes.xyxy
            cls = r.boxes.cls
            cf = r.boxes.conf

            if isinstance(xyxy, torch.Tensor):
                xyxy = xyxy.detach().cpu().numpy()
            if isinstance(cls, torch.Tensor):
                cls = cls.detach().cpu().numpy()
            if isinstance(cf, torch.Tensor):
                cf = cf.detach().cpu().numpy()

            for i in range(len(xyxy)):
                x1, y1, x2, y2 = [float(x) for x in xyxy[i]]
                cid = int(cls[i])
                # Synthesize a clockwise quad (TL, TR, BR, BL) from the AABB.
                quad = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
                dets.append({
                    "class_id": cid,
                    "class_name": self.class_names.get(cid, str(cid)),
                    "conf": float(cf[i]),
                    "quad": quad,
                })

        return dets

    # ---------------- LLM OCR request building ----------------

    @staticmethod
    def build_llm_ocr_prompt_bundle(
            *,
            fields_plan: List[Dict[str, Any]],
            locale_hint: str = "fa",
    ) -> Dict[str, Any]:
        """
        Returns a single, self-contained bundle to be sent to your LLM layer.
        You can convert it to OpenAI Responses API (or any other provider) on your server.

        fields_plan: list of:
          {
            "key": str,
            "kind": "text|numeric|date|mrz",
            "candidates": [{"idx": int, "det_conf": float, "image_data_url": str}, ...]
          }

        Expected LLM response:
          {
            "fields": {
              "<key>": {
                 "candidate_idx": 0,
                 "text": "...",          # what you read
                 "confidence": 0.0-1.0,  # optional but recommended
                 "normalized": "..."     # optional; server can re-normalize anyway
              }, ...
            }
          }
        """
        schema = {
            "schema_version": 1,
            "task": "kyc_field_ocr",
            "locale_hint": locale_hint,
            "fields": fields_plan,
            "output_contract": {
                "type": "json",
                "top_level_keys": ["fields"],
                "field_object": {
                    "candidate_idx": "int (required)",
                    "text": "string|null (required)",
                    "confidence": "float 0..1 (optional)",
                    "normalized": "string|null (optional)"
                }
            }
        }
        return schema

    @staticmethod
    def anthropic_content_items_from_bundle(bundle: Dict[str, Any]) -> List[Dict[str, Any]]:
        fields = bundle.get("fields") or []
        keys_list = [f.get("key") for f in fields]

        # >>> تغییر اصلی: هاردکد کردن دستورالعمل اعداد فارسی <<<
        # اینجا دیگر شرط نمی گذاریم، مستقیم به مدل دستور می دهیم
        intro = (
            "You are a specialized OCR engine for Iranian ID cards (National Card/Shenasnameh).\n"
            "CRITICAL INSTRUCTION FOR DIGITS:\n"
            "The images contain Persian/Arabic numerals (۰, ۱, ۲, ۳, ۴, ۵, ۶, ۷, ۸, ۹).\n"
            "You MUST transcribe them directly as standard English digits (0-9).\n"
            " - Visible: '۳۴۸' -> Output: '348'\n"
            " - Visible: '۱۴۰۲' -> Output: '1402'\n"
            "\n"
            "OUTPUT RULES:\n"
            "1. Return ONLY valid JSON. No markdown.\n"
            "2. If a field is blurry or unreadable, set 'text': null.\n"
            "3. Do not include any explanations.\n"
            "\n"
            "JSON Response Structure:\n"
            "{\n"
            "  \"fields\": {\n"
            "    \"KEY\": {\"text\": \"value\", \"confidence\": 0.99, \"candidate_idx\": 0}\n"
            "  }\n"
            "}\n"
            f"REQUIRED FIELDS: {', '.join(keys_list)}"
        )

        items = [{"type": "text", "text": intro}]

        for f in fields:
            key = f.get("key")
            kind = f.get("kind")
            cands = f.get("candidates") or []

            # راهنمایی خاص برای هر فیلد
            field_hint = f"\nTarget Field: '{key}' (Type: {kind})"
            if kind == "numeric":
                field_hint += " -> EXTRACT ONLY DIGITS (0-9)"

            items.append({"type": "text", "text": field_hint})

            # ارسال تصاویر (فقط نسخه Smart که در مرحله قبل تنظیم کردید)
            for c in cands:
                if "image_data_url" in c:
                    _, b64_data = c["image_data_url"].split(",", 1)
                    items.append({
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/png",
                            "data": b64_data,
                        },
                    })

        items.append({"type": "text", "text": "\nJSON Output:"})
        return items

    @staticmethod
    def openai_content_items_from_bundle(bundle: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Helper: convert bundle into OpenAI Responses API `input` content items in ONE request.
        (You can do the same on your server; provided for clarity.)

        Output is a list of items like:
          {"type":"input_text","text":"..."}
          {"type":"input_image","image_url":"data:image/jpeg;base64,..."}
        """
        fields = bundle.get("fields") or []
        # System-style instruction inside the user prompt (keeps it portable).
        intro = (
            "You are a deterministic OCR engine for KYC field crops.\n"
            "Return ONLY valid JSON. No markdown.\n"
            "\n"
            "Global rules:\n"
            "- Do NOT guess. If unreadable or ambiguous: text=null and confidence<=0.3.\n"
            "- Never invent missing characters.\n"
            "- Preserve native letters for names (do not transliterate).\n"
            "\n"
            "Numeric/date rules:\n"
            "- For kind=numeric: output ONLY ASCII digits 0-9 (no spaces, no separators).\n"
            "- For kind=date: output as seen; if you can normalize to YYYY-MM-DD, also set normalized.\n"
            "- NEVER drop leading zeros.\n"
            "\n"
            "Constraint handling (IMPORTANT):\n"
            "- Each FIELD may provide expected_len and/or regex.\n"
            "- If kind=numeric and expected_len>0: you MUST return exactly that many digits.\n"
            "- If regex is provided: your text MUST match it exactly.\n"
            "- If you cannot satisfy constraints with high confidence: text=null.\n"
            "\n"
            "Output JSON format:\n"
            "{ \"fields\": { \"<key>\": {\"candidate_idx\":0, \"text\":null, \"confidence\":0.0, \"normalized\":null} } }\n"
        )



        items: List[Dict[str, Any]] = [{"type": "input_text", "text": intro}]
        # Keep a deterministic ordering for mapping images.
        for f in fields:
            key = f.get("key")
            kind = f.get("kind")
            cands = f.get("candidates") or []
            expected_len = int(f.get("expected_len") or 0)
            regex = (f.get("regex") or "").strip()
            items.append({
                "type": "input_text",
                "text": f"\nFIELD key={key} kind={kind} expected_len={expected_len} regex={regex} candidates={len(cands)}"
            })

            for c in cands:
                idx = c.get("idx")
                det_conf = c.get("det_conf")
                items.append({"type": "input_text", "text": f"Candidate idx={idx} det_conf={det_conf}"})

                items.append({"type": "input_image", "image_url": c.get("image_data_url")})
        items.append({"type": "input_text", "text": "\nReturn ONLY JSON. No markdown."})
        return items

    @staticmethod
    def gemini_parts_from_bundle(bundle: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Convert an OCR bundle into Gemini `parts` for a single request.

        Produces text parts (instructions + per-field headers) interleaved
        with `inline_data` image parts. Malformed fields/candidates are
        skipped defensively rather than raising.
        """
        fields = bundle.get("fields") or []
        # Only well-formed string keys are advertised as required fields.
        keys_list = [str(f.get("key")) for f in fields if isinstance(f.get("key"), str) and f.get("key")]

        intro = (
            "You are a deterministic OCR engine for KYC field crops.\n"
            "Return ONLY valid JSON. No markdown.\n"
            "\n"
            "Global rules:\n"
            "- Do NOT guess. If unreadable or ambiguous: text=null and confidence<=0.3.\n"
            "- Never invent missing characters.\n"
            "- Preserve native letters for names (do not transliterate).\n"
            "\n"
            "Numeric/date rules:\n"
            "- For kind=numeric: output ONLY ASCII digits 0-9 (no spaces, no separators).\n"
            "- For kind=date: output as seen; if you can normalize to YYYY-MM-DD, also set normalized.\n"
            "- NEVER drop leading zeros.\n"
            "\n"
            "Constraint handling (IMPORTANT):\n"
            "- Each field may provide expected_len and/or regex.\n"
            "- If kind=numeric and expected_len>0: you MUST return exactly that many digits.\n"
            "- If regex is provided: your text MUST match it exactly.\n"
            "- If you cannot satisfy constraints with high confidence: text=null.\n"
            "\n"
            "Output JSON format:\n"
            '{\"fields\": {\"<key>\": {\"candidate_idx\":0, \"text\": null, \"confidence\":0.0, \"normalized\": null}}}\n'
            f"REQUIRED FIELDS: {', '.join(keys_list)}"
        )

        parts: List[Dict[str, Any]] = [{"text": intro}]
        # Captures the mime type out of "data:<mime>;base64,<payload>".
        data_url_re = re.compile(r"^data:([^;]+);base64,")

        for f in fields:
            if not isinstance(f, dict):
                continue

            key = f.get("key")
            if not isinstance(key, str) or not key.strip():
                continue

            kind = (f.get("kind") or "text").lower()
            cands = f.get("candidates") or []
            expected_len = f.get("expected_len")
            regex = (f.get("regex") or "").strip()

            # expected_len may be absent or non-numeric; treat as "no constraint".
            try:
                expected_len_i = int(expected_len or 0)
            except Exception:
                expected_len_i = 0

            parts.append({
                "text": f"\nFIELD key={key} kind={kind} expected_len={expected_len_i} regex={regex} candidates={len(cands)}"
            })

            for c in cands:
                if not isinstance(c, dict):
                    continue

                idx = c.get("idx")
                det_conf = c.get("det_conf")
                image_data_url = c.get("image_data_url", "")

                # Skip anything that is not a base64 data URL.
                if not isinstance(image_data_url, str) or "base64," not in image_data_url:
                    continue

                parts.append({"text": f"Candidate idx={idx} det_conf={det_conf}"})

                # Parse the real mime type; fall back to PNG if the prefix is odd.
                m = data_url_re.match(image_data_url.strip())
                mime = "image/png"
                if m:
                    mime = m.group(1) or mime
                _, data = image_data_url.split("base64,", 1)
                parts.append({"inline_data": {"mime_type": mime, "data": data}})

        parts.append({"text": "Return ONLY JSON. No markdown."})
        return parts

    def extract_with_llm(
            self,
            doc_image_path: str,
            doc_cfg: DocConfig,
            dets: List[Dict[str, Any]],
            llm_ocr: Callable[[Dict[str, Any]], Dict[str, Any]],
            locale_hint: Optional[str] = None,
            user_input: Optional[Dict[str, str]] = None,
            debug: bool = False,
            jpeg_quality: int = _OCR_JPEG_QUALITY,
            max_dim: int = _OCR_MAX_DIM,
    ) -> Dict[str, Any]:
        """
        Crop YOLO-detected fields, OCR them all in ONE llm_ocr(bundle) call,
        and map the response back into the per-field extraction pack.

        Steps:
          1) Group detections by class, keep the top candidates per rule,
             crop + base64-encode each candidate.
          2) Build a single prompt bundle and call ``llm_ocr(bundle)``.
          3) Map the LLM's per-key answers back to detections via
             candidate_idx, normalize values server-side, and assemble
             the ``fields`` dict (same shape as before).

        Returns {"fields": ..., "internals": ...}; with ``debug=True`` also
        attaches the (base64-stripped) bundle and raw LLM response.

        NOTE: candidate crops are written to ``llm_ocr_debug/`` only when
        ``debug`` is True — they contain PII and must not be dumped in
        production runs.

        Raises RuntimeError if the document image cannot be read.
        """
        user_input = user_input or {}
        log = logging.getLogger(__name__)

        img = cv2.imread(doc_image_path)
        if img is None:
            raise RuntimeError(f"Cannot read image: {doc_image_path}")

        # Group detections by class name, best detection confidence first.
        by_class: Dict[str, List[Dict[str, Any]]] = {}
        for d in dets:
            by_class.setdefault(d["class_name"], []).append(d)
        for k in by_class:
            by_class[k].sort(key=lambda x: -float(x.get("conf", 0.0)))

        # Keep candidate structures so we can map LLM's candidate_idx back to quad.
        candidates_for_key: Dict[str, List[Dict[str, Any]]] = {}

        def top_k(rule: FieldRule) -> int:
            # Per-rule candidate budget, capped at 6 to bound prompt size.
            k = int(rule.max_candidates or 1)
            if k <= 0:
                k = 1
            return min(k, 6)

        # ---- Build fields_plan for LLM (only OCR kinds) ----
        fields_plan: List[Dict[str, Any]] = []
        internals: Dict[str, Any] = {"quads": {}, "doc_image_shape": list(img.shape[:2])}
        fields: Dict[str, Any] = {}

        for rule in doc_cfg.rules:
            # Collect detections for any matching class_name.
            candidates: List[Dict[str, Any]] = []
            for cn in rule.class_names:
                candidates.extend(by_class.get(cn, []))
            if not candidates:
                continue

            candidates = [c for c in candidates if float(c.get("conf", 0.0)) >= float(rule.min_det_conf)]
            candidates.sort(key=lambda x: -float(x.get("conf", 0.0)))
            candidates = candidates[:top_k(rule)]
            if not candidates:
                continue

            kind = (rule.ocr_kind or "text").lower()

            # kind=="none": detection-only field (e.g. photo) — never sent to the LLM.
            if kind == "none":
                best = candidates[0]
                fields[rule.key] = {
                    "class_name": rule.class_names[0],
                    "det_conf": float(best["conf"]),
                    "ocr_conf": None,
                    "value_raw": None,
                    "value_norm": None,
                    "ocr_method": None,
                    "quad": best["quad"],
                    "candidate_idx": 0,
                }
                internals["quads"][rule.key] = best["quad"]
                continue

            # OCR fields: crop + encode candidates.
            # Length/regex constraints come from server policy and are also
            # surfaced to the LLM so it can self-check.
            expected_len = int(
                rule.expected_len
                or (rule.constraints or {}).get("length")
                or 0
            )
            regex = str((rule.constraints or {}).get("regex") or "").strip()

            plan_entry = {
                "key": rule.key,
                "kind": kind,
                "expected_len": expected_len,
                "regex": regex,
                "candidates": []
            }

            candidates_for_key[rule.key] = []
            debug_dir = "llm_ocr_debug"
            if debug:
                # Only materialize the PII dump directory in debug mode.
                os.makedirs(debug_dir, exist_ok=True)

            for det in candidates:
                quad = np.array(det["quad"], dtype=np.float32)
                # Numeric/date crops get a larger margin so edge digits survive.
                quad = expand_quad(quad, 1.15 if kind in ("numeric", "date") else 1.08)
                crop = warp_quad_to_rect(img, quad)

                # OCR only with sharpen preprocessing.
                variants = ["sharpen"]

                for v in variants:
                    # idx is the position within THIS field's candidate list.
                    idx = len(plan_entry["candidates"])
                    data_url = _encode_image_for_llm(crop, kind=kind, variant=v, max_dim=max_dim, jpeg_quality=jpeg_quality)
                    if debug:
                        # Best-effort dump of the exact crop sent to the LLM.
                        try:
                            img_filename = f"{rule.key}_{v}_{idx}.jpg"
                            img_path = os.path.join(debug_dir, img_filename)
                            header, encoded = data_url.split(",", 1)
                            with open(img_path, "wb") as fh:
                                fh.write(base64.b64decode(encoded))
                        except Exception as e:
                            log.warning("Failed to save debug image: %s", e)

                    plan_entry["candidates"].append({
                        "idx": idx,
                        "det_conf": float(det["conf"]),
                        "image_data_url": data_url,
                    })
                    candidates_for_key[rule.key].append(det)

            if plan_entry["candidates"]:
                fields_plan.append(plan_entry)

        # ---- ONE request to LLM ----
        lh = str(locale_hint or doc_cfg.ocr_locale_hint or default_locale_hint(doc_cfg.country))
        bundle = self.build_llm_ocr_prompt_bundle(fields_plan=fields_plan, locale_hint=lh)
        llm_resp = llm_ocr(bundle) or {}
        llm_fields = llm_resp.get("fields") if isinstance(llm_resp, dict) else None
        if not isinstance(llm_fields, dict):
            llm_fields = {}

        # ---- Map results back into engine output ----
        def normalize_value(rule: FieldRule, kind: str, value: str) -> str:
            # Server-side canonicalization; LLM-provided "normalized" is
            # re-normalized through here anyway.
            if kind == "numeric":
                return normalize_digits_only(value)
            if kind == "date":
                return collapse_spaces(to_en_digits(value))
            if rule.key in ("first_name", "last_name", "full_name"):
                return normalize_name(value)
            if rule.key in ("id_number", "passport_no"):
                return normalize_digits_only(value) if kind == "numeric" else normalize_id_alphanum(value)
            return normalize_fa_text(value)

        def heuristic_conf(kind: str, text: Optional[str], expected_len: int = 0) -> float:
            # Fallback confidence when the LLM does not report one.
            t = (text or "").strip()
            if not t:
                return 0.0
            if kind == "numeric":
                d = normalize_digits_only(t)
                if expected_len and len(d) == expected_len:
                    return 0.9
                return 0.65
            if kind == "date":
                if parse_date_loose(t):
                    return 0.85
                return 0.55
            if kind == "mrz":
                if "<" in t and len(t.splitlines()) >= 2:
                    return 0.85
                return 0.55
            return 0.70

        for rule in doc_cfg.rules:
            kind = (rule.ocr_kind or "text").lower()
            if kind == "none":
                continue
            # Skip fields that were never planned (no surviving candidates).
            if rule.key not in candidates_for_key:
                continue

            resp = llm_fields.get(rule.key) or {}
            if not isinstance(resp, dict):
                resp = {}

            cand_idx = resp.get("candidate_idx")
            try:
                cand_idx = int(cand_idx)
            except Exception:
                cand_idx = 0

            det_list = candidates_for_key.get(rule.key) or []
            if not det_list:
                continue
            # Out-of-range indices fall back to the best candidate.
            if cand_idx < 0 or cand_idx >= len(det_list):
                cand_idx = 0

            det = det_list[cand_idx]
            det_conf = float(det.get("conf", 0.0))

            raw_text = resp.get("text")
            if raw_text is None:
                raw_text = resp.get("value")
            if raw_text is None:
                raw_text = ""
            raw_text = str(raw_text)

            # Guard against constraints being None (was rule.constraints.get(...),
            # which raised AttributeError for rules without constraints).
            exp_len = int(rule.expected_len or (rule.constraints or {}).get("length") or 0)
            conf = resp.get("confidence")
            if conf is None:
                conf = heuristic_conf(kind, raw_text, expected_len=exp_len)
            try:
                conf = float(conf)
            except Exception:
                conf = heuristic_conf(kind, raw_text, expected_len=exp_len)
            conf = self._clamp01(conf)

            norm_text = resp.get("normalized")
            if norm_text is None or str(norm_text).strip() == "":
                norm_text = normalize_value(rule, kind, raw_text)
            else:
                norm_text = normalize_value(rule, kind, str(norm_text))

            fields[rule.key] = {
                "class_name": rule.class_names[0],
                "det_conf": det_conf,
                "ocr_conf": conf,
                "value_raw": raw_text.strip(),
                "value_norm": norm_text,
                "ocr_method": "llm",
                "quad": det["quad"],
                "candidate_idx": cand_idx,
            }
            internals["quads"][rule.key] = det["quad"]

        out = {"fields": fields, "internals": internals}
        if debug:
            # Raw YOLO detections are intentionally omitted (too verbose).
            # Strip base64 payloads before exposing the bundle for logging.
            import copy
            clean_bundle = copy.deepcopy(bundle)
            try:
                for f in clean_bundle.get("fields", []):
                    for c in f.get("candidates", []):
                        if "image_data_url" in c:
                            c["image_data_url"] = "<BASE64_REMOVED_FOR_LOG>"
            except Exception:
                pass

            out["llm_bundle"] = clean_bundle
            out["llm_response_raw"] = llm_resp

        return out

    # ---------------- Scoring ----------------

    def score(
            self,
            doc_cfg: DocConfig,
            extracted_fields: Dict[str, Any],
            user_input: Optional[Dict[str, str]],
            face_pack: Optional[Dict[str, Any]],
            debug: bool
    ) -> Dict[str, Any]:
        """
        Rule-based scoring of extracted fields against policy and user input.

        Per field: checks detection/OCR confidence gates, an extraction
        quality score (geometric mean of det/ocr confidence), constraint
        validity, and — when user input is available — a match score
        (exact / fuzzy / date). Aggregates weighted averages into
        coverage / extraction / match_core / match_all, optionally reads a
        face-match value, then delegates the APPROVE/REVIEW/REJECT decision
        to `decide`.

        Returns {"decision", "reasons", "scores", "per_field"}; `debug=True`
        adds raw/normalized values and geometry to each per_field record.
        """
        user_input = user_input or {}

        required_rules = [r for r in doc_cfg.rules if r.required]

        found_required_det = 0
        det_ok_keys: set = set()

        # (value, weight) pairs fed into the weighted average below.
        extraction_pairs: List[Tuple[float, float]] = []
        match_pairs_all: List[Tuple[float, float]] = []
        match_pairs_core: List[Tuple[float, float]] = []

        per_field: Dict[str, Any] = {}
        mismatch_flags: List[str] = []
        invalid_flags: List[str] = []

        def validate_constraints(rule: FieldRule, kind: str, raw: Optional[str], norm: Optional[str]) -> Tuple[bool, Optional[str]]:
            # Returns (ok, reason). Malformed constraint values (non-int
            # lengths, bad regex) are ignored rather than failing the field.
            c = rule.constraints or {}
            if not c:
                return True, None

            raw_s = raw or ""
            norm_s = norm or ""

            length = c.get("length")
            min_length = c.get("min_length")
            max_length = c.get("max_length")

            # Length checks run on the NORMALIZED value.
            if length is not None:
                try:
                    L = int(length)
                    if len(norm_s) != L:
                        return False, f"length_expected_{L}"
                except Exception:
                    pass
            if min_length is not None:
                try:
                    L = int(min_length)
                    if len(norm_s) < L:
                        return False, f"min_length_{L}"
                except Exception:
                    pass
            if max_length is not None:
                try:
                    L = int(max_length)
                    if len(norm_s) > L:
                        return False, f"max_length_{L}"
                except Exception:
                    pass

            prefix = c.get("prefix")
            if prefix:
                pref = str(prefix)
                if not str(norm_s).startswith(pref):
                    return False, f"prefix_{pref}"

            regex = c.get("regex")
            if regex:
                try:
                    if re.fullmatch(str(regex), str(norm_s)) is None:
                        return False, "regex_mismatch"
                except re.error:
                    pass

            # Date parseability is checked on the RAW value.
            if kind == "date" and c.get("must_parse", False):
                if parse_date_loose(raw_s) is None:
                    return False, "date_unparsed"

            return True, None

        def wavg(pairs: List[Tuple[float, float]]) -> float:
            # Weighted average; 0.0 for empty input or all-zero weights.
            if not pairs:
                return 0.0
            sw = sum(w for _, w in pairs)
            return float(sum(v * w for v, w in pairs) / sw) if sw > 0 else 0.0

        for rule in doc_cfg.rules:
            f = extracted_fields.get(rule.key)
            if not f:
                per_field[rule.key] = {
                    "present": False, "det_ok": False, "ocr_ok": False, "extraction": 0.0,
                    "match": None, "reason": "missing"
                }
                continue

            det_conf = self._clamp01(float(f.get("det_conf", 0.0)))
            det_ok = det_conf >= float(rule.min_det_conf)
            if det_ok:
                det_ok_keys.add(rule.key)

            kind = (rule.ocr_kind or "text").lower()

            if kind == "none":
                # Detection-only fields: extraction quality = detection quality.
                ocr_ok = True
                extraction = math.sqrt(det_conf * 1.0)
                ocr_conf = None
                raw_value = None
                norm_value = None
            else:
                ocr_conf = self._clamp01(float(f.get("ocr_conf", 0.0) or 0.0))
                raw_value = f.get("value_raw") or ""
                norm_value = f.get("value_norm") or ""
                ocr_ok = (ocr_conf >= float(rule.min_ocr_conf)) and bool(str(norm_value).strip())
                # Geometric mean of det/ocr confidence; 0.01 floor avoids zeroing out.
                extraction = math.sqrt(det_conf * max(0.01, ocr_conf))

            if rule.required and det_ok:
                found_required_det += 1

            # Validator (e.g. checksum) + declarative constraints; first failure wins.
            valid_ok = True
            valid_reason = None
            if kind != "none" and rule.validator:
                ok, reason = run_validator(rule.validator, str(norm_value or ""), rule.constraints or {})
                if not ok:
                    valid_ok = False
                    valid_reason = reason or f"{rule.validator}_invalid"

            c_ok, c_reason = validate_constraints(rule, kind, raw_value, norm_value)
            if not c_ok:
                valid_ok = False
                valid_reason = valid_reason or c_reason

            if not valid_ok:
                invalid_flags.append(f"invalid:{rule.key}:{valid_reason}")

            extraction_pairs.append((float(extraction), float(rule.weight)))

            # ---- Match against user-supplied expected value (if any) ----
            expected = self._get_expected(rule, user_input)
            mscore: Optional[float] = None
            mreason = "no_user_input"

            if expected and kind != "none":
                mt = (rule.match_type or "optional").lower()

                if mt == "exact":
                    # Normalize both sides identically before comparing.
                    got_norm = norm_value or ""
                    exp_norm = expected

                    if kind == "numeric":
                        got_norm = normalize_digits_only(got_norm)
                        exp_norm = normalize_digits_only(exp_norm)
                    elif rule.key in ("first_name", "last_name", "full_name"):
                        got_norm = normalize_name(got_norm)
                        exp_norm = normalize_name(exp_norm)
                    else:
                        got_norm = normalize_fa_text(got_norm)
                        exp_norm = normalize_fa_text(exp_norm)

                    mscore = 1.0 if got_norm == exp_norm else 0.0
                    mreason = "exact_match" if mscore == 1.0 else "exact_mismatch"

                elif mt == "fuzzy":
                    # Fuzzy uses the RAW value (token-level similarity).
                    mscore = float(token_set_similarity(raw_value or "", expected))
                    mreason = "fuzzy" if mscore >= float(rule.match_threshold) else "fuzzy_below_threshold"

                elif mt == "date":
                    got = parse_date_loose(raw_value or "")
                    exp = parse_date_loose(expected)
                    if got and exp:
                        mscore = 1.0 if got == exp else 0.0
                        mreason = "date_match" if mscore == 1.0 else "date_mismatch"
                    else:
                        # Unparseable date on either side: no match signal at all.
                        mscore = None
                        mreason = "date_unparsed"

            if mscore is not None:
                match_pairs_all.append((float(mscore), float(rule.weight)))
                if rule.match_gate:
                    match_pairs_core.append((float(mscore), float(rule.weight)))

            # Hard mismatch flags block approval in decide().
            if rule.must_match:
                if mscore == 0.0 and mreason in ("exact_mismatch", "date_mismatch"):
                    mismatch_flags.append(f"mismatch:{rule.key}")
                if mreason == "fuzzy_below_threshold":
                    mismatch_flags.append(f"mismatch:{rule.key}")

            rec: Dict[str, Any] = {
                "present": det_ok,
                "det_ok": det_ok,
                "ocr_ok": ocr_ok,
                "det_conf": det_conf,
                "ocr_conf": ocr_conf,
                "extraction": float(extraction),
                "match": mscore,
                "reason": mreason,
                "valid": valid_ok,
                "valid_reason": valid_reason,
            }
            if debug:
                rec.update({
                    "raw": raw_value,
                    "normalized": norm_value,
                    "class_name": f.get("class_name"),
                    "ocr_method": f.get("ocr_method"),
                    "quad": f.get("quad"),
                    "candidate_idx": f.get("candidate_idx"),
                })
            per_field[rule.key] = rec

        # ---- Aggregates ----
        coverage = (found_required_det / max(1, len(required_rules))) if required_rules else 1.0
        extraction_score = wavg(extraction_pairs)
        match_all = (wavg(match_pairs_all) if match_pairs_all else None)
        match_core = (wavg(match_pairs_core) if match_pairs_core else None)

        count_gate_ok = True
        if doc_cfg.min_detected_fields_count is not None:
            count_gate_ok = (len(det_ok_keys) >= int(doc_cfg.min_detected_fields_count))

        face_val = None
        face_reason = None
        if face_pack is not None:
            face_reason = face_pack.get("reason")
            if doc_cfg.face_metric == "cosine":
                face_val = face_pack.get("cosine")
            else:
                face_val = face_pack.get("score01")

        # Document score: equal blend of extraction quality and (core) match.
        base_match = float(match_core if match_core is not None else (match_all or 0.0))
        doc_score = self._clamp01(0.5 * float(extraction_score) + 0.5 * base_match)

        # Face is ON/OFF only:
        final_score = doc_score
        if doc_cfg.require_face_match:
            if face_val is not None:
                final_score = self._clamp01(0.65 * doc_score + 0.35 * self._clamp01(float(face_val)))
            else:
                final_score = doc_score

        decision, reasons = decide(
            doc_cfg=doc_cfg,
            coverage=coverage,
            extraction=extraction_score,
            match_all=match_all,
            match_core=match_core,
            count_gate_ok=count_gate_ok,
            detected_count=len(det_ok_keys),
            mismatch_flags=mismatch_flags,
            invalid_flags=invalid_flags,
            face_value=face_val,
            face_reason=face_reason,
        )

        return {
            "decision": decision,
            "reasons": reasons,
            "scores": {
                "coverage": coverage,
                "extraction": extraction_score,
                "match_all": match_all,
                "match_core": match_core,
                "detected_fields_count": len(det_ok_keys),
                "count_gate_ok": count_gate_ok,
                "face": face_val,
                "face_metric": doc_cfg.face_metric,
                "doc_score": doc_score,
                "final_score": final_score,
            },
            "per_field": per_field,
        }


# ============================================================
# 8) Decision logic (no optional selfie rescue)
# ============================================================

def decide(
        doc_cfg: DocConfig,
        coverage: float,
        extraction: float,
        match_all: Optional[float],
        match_core: Optional[float],
        count_gate_ok: bool,
        detected_count: int,
        mismatch_flags: List[str],
        invalid_flags: List[str],
        face_value: Optional[float],
        face_reason: Optional[str],
) -> Tuple[str, List[str]]:
    """Map aggregate KYC signals onto an APPROVE / REVIEW / REJECT decision.

    Gates fire in strict order: detected-field-count gate, hard coverage
    floor, invalid/mismatch field flags, face policy (on/off), and finally
    the approve path. Returns ``(decision, reasons)`` where ``reasons``
    explains any non-approval outcome.
    """
    why: List[str] = []

    # Hard gate: not enough fields detected on the document.
    if doc_cfg.min_detected_fields_count is not None and not count_gate_ok:
        why.append("detected_fields_count_too_low")
        why.append(f"detected={detected_count}_min={doc_cfg.min_detected_fields_count}")
        return "REJECT", why

    # Hard gate: coverage below the absolute rejection floor.
    if coverage < doc_cfg.reject_below_coverage:
        why.append("coverage_too_low")
        return "REJECT", why

    # Invalid fields, then mismatched fields, each force manual review.
    # At most 12 flags are surfaced per category to keep reasons bounded.
    for flag_group in (invalid_flags, mismatch_flags):
        if flag_group:
            why.extend(flag_group[:12])
            return "REVIEW", why

    # Face policy is binary: when required, missing or sub-threshold -> review.
    if doc_cfg.require_face_match:
        if face_value is None:
            why.append(f"face_required_but_unavailable:{face_reason or 'unknown'}")
            return "REVIEW", why
        if float(face_value) < float(doc_cfg.face_match_threshold):
            why.append("face_below_threshold")
            why.append(f"face={face_value:.3f}_thr={doc_cfg.face_match_threshold:.3f}")
            return "REVIEW", why

    # Approve path: coverage and extraction both clear their thresholds.
    if coverage >= doc_cfg.approve_min_coverage and extraction >= doc_cfg.approve_min_extraction:
        if match_core is None:
            # No core user input to match against: approve only when the
            # extraction clears an extra safety buffer, else review.
            buffered_floor = doc_cfg.approve_min_extraction + doc_cfg.approve_no_input_extra_buffer
            if extraction >= buffered_floor:
                return "APPROVE", why
            why.append("no_core_user_input")
            return "REVIEW", why

        if match_core >= doc_cfg.approve_min_match_core:
            return "APPROVE", why

        why.append("core_match_below_threshold")
        if match_all is not None:
            why.append(f"match_all={match_all:.3f}")
        return "REVIEW", why

    # Fallback: coverage is review-worthy; annotate what fell short.
    if coverage >= doc_cfg.review_min_coverage:
        if extraction < doc_cfg.approve_min_extraction:
            why.append("extraction_low")
        if coverage < doc_cfg.approve_min_coverage:
            why.append("partial_coverage")
        return "REVIEW", why

    why.append("coverage_insufficient")
    return "REVIEW", why


# ============================================================
# 9) Swap utilities (same as before)
# ============================================================

def swap_field_keys(fields: Dict[str, Any], internals: Dict[str, Any], a: str, b: str, reason: str) -> None:
    """Exchange two field entries (and their quads) in place, logging the swap.

    Values are swapped only when both keys exist in the respective mapping;
    a swap event is always appended to ``internals["swap_events"]``.
    """
    # Same swap rule applies to the OCR fields and to the detection quads.
    for mapping in (fields, internals.get("quads", {})):
        if a in mapping and b in mapping:
            mapping[a], mapping[b] = mapping[b], mapping[a]
    internals.setdefault("swap_events", []).append({"a": a, "b": b, "reason": reason})


def apply_config_swaps(fields: Dict[str, Any], internals: Dict[str, Any], cfg: DocConfig) -> None:
    """Apply every server-configured swap pair from ``cfg.swap_pairs``."""
    for pair in cfg.swap_pairs:
        left, right = pair
        swap_field_keys(fields, internals, left, right, reason="config_swap_pairs")


def best_name_assignment(
        ocr_first_raw: str,
        ocr_last_raw: str,
        exp_first: str,
        exp_last: str,
        margin: float = 0.06
) -> Dict[str, Any]:
    """Decide whether OCR first/last name fields look swapped vs. user input.

    Scores the straight assignment (first->first, last->last) against the
    crossed one (first->last, last->first); the swap is suggested only when
    the crossed score beats the straight score by more than ``margin``.
    """
    straight_ff = token_set_similarity(ocr_first_raw or "", exp_first or "")
    straight_ll = token_set_similarity(ocr_last_raw or "", exp_last or "")
    straight = (straight_ff + straight_ll) / 2.0

    crossed_fl = token_set_similarity(ocr_first_raw or "", exp_last or "")
    crossed_lf = token_set_similarity(ocr_last_raw or "", exp_first or "")
    crossed = (crossed_fl + crossed_lf) / 2.0

    if crossed > straight + margin:
        return {
            "swap": True,
            "score_normal": straight,
            "score_swapped": crossed,
            "details": {"first->last": crossed_fl, "last->first": crossed_lf},
        }
    return {
        "swap": False,
        "score_normal": straight,
        "score_swapped": crossed,
        "details": {"first->first": straight_ff, "last->last": straight_ll},
    }


def apply_name_swap_if_needed(fields: Dict[str, Any], internals: Dict[str, Any], cfg: DocConfig, user_input: Dict[str, str]) -> None:
    """Swap first_name/last_name fields when OCR likely confused them.

    Runs only when the config enables name swapping, both name fields were
    detected, and the user supplied both expected names; the heuristic score
    comes from ``best_name_assignment``.

    Fix: previously a second swap event was appended on top of the one
    already recorded by ``swap_field_keys`` (see its trailing
    ``setdefault("swap_events", ...).append``), so every heuristic swap was
    logged twice. The similarity metadata is now attached to the single
    existing event instead.
    """
    if not cfg.enable_name_swap:
        return
    if "first_name" not in fields or "last_name" not in fields:
        return
    exp_first = user_input.get("first_name") or ""
    exp_last = user_input.get("last_name") or ""
    if not exp_first or not exp_last:
        return

    f_raw = fields["first_name"].get("value_raw") or ""
    l_raw = fields["last_name"].get("value_raw") or ""

    info = best_name_assignment(f_raw, l_raw, exp_first, exp_last, margin=cfg.name_swap_margin)
    if info.get("swap"):
        swap_field_keys(fields, internals, "first_name", "last_name", reason="heuristic_name_swap")
        # swap_field_keys has just appended the event; enrich it with the
        # similarity breakdown instead of appending a duplicate entry.
        internals["swap_events"][-1]["meta"] = info


# ============================================================
# 10) Public API
# ============================================================

def run_kyc(
        *,
        model_path: str,
        doc_image_path: str,
        doc_id: str,
        doc_config_payload: Optional[Dict[str, Any]] = None,
        doc_config: Optional[DocConfig] = None,

        # Optional OCR hook (if omitted, Eden AI callback is used)
        llm_ocr: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,

        user_input: Optional[Dict[str, str]] = None,
        selfie_image_path: Optional[str] = None,

        yolo_conf: float = 0.25,
        yolo_iou: float = 0.6,
        device: str = "0",
        max_det: int = 200,

        cache_engine: bool = True,
        debug: bool = False,
) -> Dict[str, Any]:
    """End-to-end KYC pipeline: detect fields with YOLO, OCR them via a
    single LLM call, apply swap heuristics, optionally face-match, and score.

    Exactly one of ``doc_config_payload`` / ``doc_config`` must be provided;
    the payload takes precedence. Raises ``ValueError`` when both are None.
    Returns a result dict with decision, scores, reasons, per_field, plus
    ``selfie_mode`` / ``next_steps`` and, when ``debug``, the internals.
    """
    # Resolve the server-driven document configuration (payload wins).
    if doc_config_payload is not None:
        cfg = doc_config_from_payload(doc_config_payload)
    elif doc_config is not None:
        cfg = doc_config
    else:
        raise ValueError("doc_config_payload or doc_config must be provided (server-driven).")

    # Default OCR callback is the built-in Eden AI one.
    ocr_fn = llm_ocr if llm_ocr is not None else eden_llm_ocr

    engine = KYCEngine(model_path) if not cache_engine else get_engine(model_path)

    detections = engine.detect(doc_image_path, conf=yolo_conf, iou=yolo_iou, device=device, max_det=max_det)

    pack = engine.extract_with_llm(
        doc_image_path,
        cfg,
        detections,
        llm_ocr=ocr_fn,
        user_input=user_input,
        debug=debug
    )
    fields, internals = pack["fields"], pack["internals"]

    # Deterministic config swaps first, then the heuristic name swap.
    apply_config_swaps(fields, internals, cfg)
    apply_name_swap_if_needed(fields, internals, cfg, user_input or {})

    # ---------------- Face (ON/OFF) ----------------
    face_pack = None
    if cfg.require_face_match:
        if not selfie_image_path:
            face_pack = {"score01": None, "cosine": None, "reason": "selfie_missing", "details": None}
        else:
            # Prefer the dedicated doc-photo quad; fall back to a generic one.
            quad = internals["quads"].get("doc_photo") or internals["quads"].get("photo")
            face_pack = compute_face_pack(doc_image_path=doc_image_path, selfie_image_path=selfie_image_path, doc_photo_quad=quad)

    scoring = engine.score(cfg, fields, user_input=user_input, face_pack=face_pack, debug=debug)

    out: Dict[str, Any] = {
        "doc_id": doc_id,
        "decision": scoring["decision"],
        "scores": scoring["scores"],
        "reasons": scoring["reasons"],
        "per_field": scoring["per_field"],
    }

    # next_steps / selfie_mode
    steps: List[str] = []
    if cfg.require_face_match and not selfie_image_path:
        steps.append("upload_selfie_required")
    if out["decision"] == "REJECT":
        steps.append("retake_document_photo")

    out["selfie_mode"] = "required" if cfg.require_face_match else "disabled"
    out["next_steps"] = steps

    if debug:
        out["detections"] = detections
        out["fields"] = fields
        out["internals"] = internals
        out["face"] = face_pack
        out["llm_bundle"] = pack.get("llm_bundle")
        out["llm_response_raw"] = pack.get("llm_response_raw")

    return out
