
# policy_provider_llm.py
# ------------------------------------------------------------
# Fetch variants from policy server + build DocConfig payloads from validation_logic
# Runs YOLO ONCE + LLM OCR ONCE (union of fields) + scores all variants.
#
from __future__ import annotations

import json
import time
import copy
import os
import logging
import urllib.parse
import urllib.request
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Callable

from identity_registry import get_identity_defaults

def _extract_yolo_model_path(payload: Dict[str, Any]) -> Optional[str]:
    """
    Find the YOLO weights path inside a policy-server response.

    Expected response shape:
    {
      "status": "success",
      "data": {
        "config": {"yolo_model_path": "<path to best.pt>"},
        "rules": [{"yolo_model_path": "<path to best.pt>", ...}]
      }
    }

    Lookup order (first truthy value wins, returned as ``str``):
      1) payload["data"]["config"]
      2) payload["data"]["rules"][0]
      3) payload["config"]  (fallback in case the key moves to top-level)

    Returns None when no location carries a truthy "yolo_model_path".
    """

    def _candidate_sections():
        # Yield the containers that may hold the path, in priority order.
        data = payload.get("data")
        if isinstance(data, dict):
            yield data.get("config")
            rules = data.get("rules")
            if isinstance(rules, list) and rules:
                yield rules[0]
        yield payload.get("config")

    for section in _candidate_sections():
        if not isinstance(section, dict):
            continue
        path = section.get("yolo_model_path")
        if path:
            return str(path)

    return None

def _http_get_json(url: str, params: Dict[str, Any], timeout_s: float = 12.0, retries: int = 2) -> Dict[str, Any]:
    """
    GET ``url`` with ``params`` as query string and return the decoded JSON body.

    Behaviour:
      - Parameters whose value is None are dropped from the query string.
      - Browser-like headers are sent; if the KYC_RULES_TOKEN environment
        variable is set it is sent as a Bearer token.
      - The body is decompressed when the server answers with
        ``Content-Encoding: gzip`` / ``x-gzip`` / ``deflate`` (both encodings
        are advertised in ``Accept-Encoding``, so both must be decodable).
      - Every failure except HTTP 403 is retried up to ``retries`` extra times,
        sleeping 0.4s, 0.8s, ... between attempts.

    Raises:
      RuntimeError: on 403 (immediately), or once all attempts are exhausted
        (HTTP error, empty body, HTML/WAF page, invalid JSON, network error).
    """
    import gzip
    import urllib.error
    import zlib

    qs = urllib.parse.urlencode({k: v for k, v in params.items() if v is not None})
    full = url + ("&" if "?" in url else "?") + qs

    headers = {
        "User-Agent": (
            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
            "AppleWebKit/537.36 (KHTML, like Gecko) "
            "Chrome/124.0.0.0 Safari/537.36"
        ),
        "Accept": "application/json, text/plain, */*",
        "Accept-Language": "en-US,en;q=0.9,fa;q=0.8",
        "Referer": "https://chat.fxtrendo.com/",
        "Connection": "close",
        "Accept-Encoding": "gzip, deflate",
    }

    token = os.environ.get("KYC_RULES_TOKEN")
    if token:
        headers["Authorization"] = f"Bearer {token}"

    log = logging.getLogger(__name__)
    last_err: Optional[Exception] = None

    for i in range(retries + 1):
        try:
            req = urllib.request.Request(full, headers=headers, method="GET")
            with urllib.request.urlopen(req, timeout=timeout_s) as resp:
                status = getattr(resp, "status", None) or resp.getcode()
                final_url = resp.geturl()
                ctype = resp.headers.get("Content-Type", "")
                enc = (resp.headers.get("Content-Encoding", "") or "").strip().lower()
                raw = resp.read()

            # Best-effort decompression: if the body is not actually compressed
            # we keep the raw bytes and let the JSON stage report the problem.
            if enc in ("gzip", "x-gzip"):
                try:
                    raw = gzip.decompress(raw)
                except (OSError, EOFError, zlib.error) as de:
                    log.debug("gzip_decode_failed url=%s err=%s", final_url, de)
            elif enc == "deflate":
                try:
                    raw = zlib.decompress(raw)
                except zlib.error:
                    # Some servers send a raw deflate stream without the zlib header.
                    try:
                        raw = zlib.decompress(raw, -zlib.MAX_WBITS)
                    except zlib.error as de:
                        log.debug("deflate_decode_failed url=%s err=%s", final_url, de)

            try:
                text = raw.decode("utf-8")
            except UnicodeDecodeError:
                text = raw.decode("latin-1", errors="replace")

            t = (text or "").lstrip()
            if not t:
                raise RuntimeError(f"Empty response body (status={status}, url={final_url}, ctype={ctype})")

            # An HTML page instead of JSON usually means a WAF / challenge page.
            if t.startswith("<") and ("<html" in t[:200].lower() or "cloudflare" in t[:400].lower()):
                preview = t[:600].replace("\n", "\\n")
                raise RuntimeError(
                    "Non-JSON response (likely WAF/Cloudflare HTML). "
                    f"status={status}, url={final_url}, ctype={ctype}, preview={preview}"
                )

            try:
                return json.loads(text)
            except json.JSONDecodeError as je:
                preview = t[:600].replace("\n", "\\n")
                raise RuntimeError(
                    f"JSON decode failed. status={status}, url={final_url}, ctype={ctype}, "
                    f"preview={preview}"
                ) from je

        except urllib.error.HTTPError as e:
            last_err = e
            try:
                body = e.read()
                try:
                    body_text = body.decode("utf-8", errors="replace")
                except Exception:
                    body_text = str(body)
                preview = (body_text or "")[:600].replace("\n", "\\n")
            except Exception:
                preview = ""

            # 403 is treated as final: it is raised immediately, without retry.
            if e.code == 403:
                raise RuntimeError(f"403 Forbidden from policy server. url={full}, preview={preview}") from e

            if i < retries:
                time.sleep(0.4 * (2 ** i))
                continue
            raise RuntimeError(f"HTTP error from {full}: {e}, preview={preview}") from e

        except Exception as e:
            # Includes the RuntimeErrors raised above (empty/HTML/invalid JSON):
            # those are retried as well, then wrapped on the final attempt.
            last_err = e
            if i < retries:
                time.sleep(0.4 * (2 ** i))
                continue
            raise RuntimeError(f"Failed to fetch JSON from {full}: {e}") from e

    # Only reachable when the loop body never ran (negative ``retries``).
    raise RuntimeError(f"Failed to fetch JSON from {full}: {last_err}")


def _coerce_json(v: Any) -> Any:
    """
    Decode ``v`` when it is a string that looks like a JSON object or array.

    - None, dicts and lists are returned untouched.
    - Blank strings become None.
    - Strings wrapped in ``{...}`` or ``[...]`` (after trimming) are parsed;
      if parsing fails the original string is returned.
    - Anything else is returned unchanged.
    """
    if v is None or isinstance(v, (dict, list)):
        return v
    if not isinstance(v, str):
        return v

    trimmed = v.strip()
    if not trimmed:
        return None

    looks_like_object = trimmed[0] == "{" and trimmed[-1] == "}"
    looks_like_array = trimmed[0] == "[" and trimmed[-1] == "]"
    if not (looks_like_object or looks_like_array):
        return v

    try:
        return json.loads(trimmed)
    except Exception:
        return v


def _find_rules_list(payload: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Locate the list of variant rules in a policy response.

    Tried in order:
      1) payload["data"]["rules"]
      2) payload["rules"]
      3) payload["data"] itself, when it is a list of rule-like rows
      4) any top-level value that is a dict with a "rules" list, or is itself
         a list of rule-like rows

    A list counts as "rule-like rows" when its first element is a dict that
    has a "required_steps" or "validation_logic" key. Returns [] if nothing
    matches.
    """

    def _is_rule_rows(candidate: Any) -> bool:
        if not (isinstance(candidate, list) and candidate):
            return False
        head = candidate[0]
        return isinstance(head, dict) and ("required_steps" in head or "validation_logic" in head)

    data = payload.get("data")
    if isinstance(data, dict) and isinstance(data.get("rules"), list):
        return data["rules"]

    top_rules = payload.get("rules")
    if isinstance(top_rules, list):
        return top_rules

    if _is_rule_rows(data):
        return data

    # Last resort: scan every top-level value for something rule-shaped.
    for value in payload.values():
        if isinstance(value, dict):
            nested = value.get("rules")
            if isinstance(nested, list):
                return nested
        elif _is_rule_rows(value):
            return value

    return []


# --- Canonical mapping: YOLO class_name -> internal field key ---
# Written as "internal key -> accepted YOLO class names" and inverted below,
# so each field's spellings (space / underscore variants) sit side by side.
_ALIASES_BY_CANON_KEY: Dict[str, Tuple[str, ...]] = {
    # old + underscore variants
    "document_body": ("Document_Body", "Document Body"),
    "doc_photo": ("Photo", "Doc Photo"),
    "doc_logo": ("Doc Logo", "Doc_Logo"),
    "full_name": ("Full Name", "Full_Name"),
    "first_name": ("First Name", "First_Name"),
    "last_name": ("Last Name", "Last_Name"),
    "id_number": ("ID Number", "ID_Number"),
    "passport_no": ("Passport No", "Passport_No"),
    "birth_date": ("Birth Date", "Birth_Date"),
    "expiry_date": ("Expiry Date", "Expiry_Date"),
    "issue_date": ("Issue Date", "Issue_Date"),
    "gender": ("Gender",),
    "nationality": ("Nationality",),
    "mrz": ("MRZ Zone", "MRZ_Zone"),
    "address": ("Address",),
    # NEW classes
    "qr_barcode": ("QR_Barcode", "QR Barcode"),
    "issuing_authority": ("Issuing_Authority", "Issuing Authority"),
    "place_of_birth": ("Place_of_Birth", "Place of Birth"),
}

_CANON_KEYS: Dict[str, str] = {
    alias: canon_key
    for canon_key, aliases in _ALIASES_BY_CANON_KEY.items()
    for alias in aliases
}

# Identity-bearing fields; they get a higher default max_candidates when rules are built.
_CORE_KEYS = {"first_name", "last_name", "full_name", "id_number", "passport_no"}


def _infer_ocr_locale_hint(country: str, defaults: Dict[str, Any]) -> str:
    """
    Pick the OCR language hint handed to the LLM.

    Precedence: defaults["ocr_locale_hint"], then the KYC_OCR_LOCALE_HINT
    environment variable, then "fa" for country "IR", otherwise "en".
    """
    explicit = defaults.get("ocr_locale_hint") or os.environ.get("KYC_OCR_LOCALE_HINT")
    if explicit:
        return str(explicit)

    country_code = (country or "").strip().upper()
    return "fa" if country_code == "IR" else "en"



def _infer_ocr_kind(class_name: str, logic: Dict[str, Any]) -> str:
    """
    Decide how a detected field should be OCR'd.

    Returns one of "none", "text", "numeric", "date", "mrz".

    An explicit, recognised value in logic["type"] or logic["ocr_kind"]
    (checked in that order) wins. An unrecognised "type" no longer hides a
    valid "ocr_kind", and non-string values are tolerated.

    Otherwise the kind is guessed from the class name; both the space and the
    underscore spelling of "document body" map to "none".
    """
    known_kinds = {"none", "text", "numeric", "date", "mrz"}

    # Support both: logic["type"] and logic["ocr_kind"]
    for field in ("type", "ocr_kind"):
        declared = str(logic.get(field) or "").strip().lower()
        if declared in known_kinds:
            return declared

    cn = (class_name or "").lower()
    if "mrz" in cn:
        return "mrz"
    if "date" in cn:
        return "date"
    if "number" in cn or " no" in cn or cn.endswith("no"):
        return "numeric"
    if "qr" in cn or "barcode" in cn:
        return "none"

    # Normalise separators so "Document Body" and "Document_Body" behave alike.
    normalized = cn.replace(" ", "_").replace("-", "_")
    if "photo" in cn or "logo" in cn or "document_body" in normalized:
        return "none"
    return "text"


def _default_match_for_key(key: str) -> Tuple[str, bool, bool, float]:
    """
    Default matching settings for a field key.

    Returns ``(match_type, match_gate, must_match, weight)``: identifier
    numbers match exactly, name fields match fuzzily, and every other field
    is optional with unit weight.
    """
    identifier_defaults = ("exact", True, True, 2.2)
    name_defaults = ("fuzzy", True, True, 1.6)
    table = {
        "id_number": identifier_defaults,
        "passport_no": identifier_defaults,
        "first_name": name_defaults,
        "last_name": name_defaults,
        "full_name": name_defaults,
    }
    return table.get(key, ("optional", False, False, 1.0))


def _constraints_from_logic(logic: Dict[str, Any]) -> Dict[str, Any]:
    """
    Pull the constraint-like entries out of a field's validation_logic.

    Length keys are normalised to their long names; the long name wins when
    both spellings are present:
      - len     -> length
      - min_len -> min_length
      - max_len -> max_length

    "prefix", "regex" and "must_parse" are copied as-is, and so are the
    checksum / validator parameters (weights, weight_list, mod, check_index,
    check_rule), so the engine can run deterministic validators without
    hard-coding per-country logic.
    """
    constraints: Dict[str, Any] = {}

    length_aliases = (
        ("length", "len"),
        ("min_length", "min_len"),
        ("max_length", "max_len"),
    )
    for canonical, alias in length_aliases:
        if canonical in logic:
            constraints[canonical] = logic.get(canonical)
        elif alias in logic:
            constraints[canonical] = logic.get(alias)

    # Format constraints first, then checksum / validator parameters
    # (the engine will ignore the latter if the validator doesn't use them).
    passthrough = (
        "prefix", "regex", "must_parse",
        "weights", "weight_list", "mod", "check_index", "check_rule",
    )
    constraints.update((name, logic[name]) for name in passthrough if name in logic)

    return constraints


def _to_bool(x: Any) -> Optional[bool]:
    """
    Interpret a loosely-typed flag.

    Returns True/False for booleans, numbers (non-zero is True) and the
    recognised words below (case-insensitive, surrounding whitespace
    ignored). Returns None for None and for anything that cannot be
    interpreted.
    """
    if x is None or isinstance(x, bool):
        return x
    if isinstance(x, (int, float)):
        return bool(x)
    if not isinstance(x, str):
        return None

    truthy_words = {"1", "true", "yes", "y", "required", "on", "enabled"}
    falsy_words = {"0", "false", "no", "n", "off", "disabled", "none", "optional"}

    token = x.strip().lower()
    if token in truthy_words:
        return True
    if token in falsy_words:
        return False
    return None


def _extract_face_policy_on_off(variant: Dict[str, Any], validation_logic: Dict[str, Any]) -> Dict[str, Any]:
    """
    Collect the raw face-match settings for a variant into one dict.

    New rule: Face is either ENABLED (required) or DISABLED. No optional mode.

    Sources, from highest to lowest precedence (later merges overwrite
    earlier ones, so the code merges them in the reverse of this order):
      1) validation_logic['__face__' / 'face_match' / '_face' / 'selfie']
      2) variant['face_match' / 'face_policy' / 'selfie' / 'selfie_policy']
      3) flat keys on the variant (require_face_match, face_match_threshold,
         face_metric), used only when the nested sources did not already
         supply the corresponding setting

    Nested values may be dicts or JSON-encoded strings.
    """
    policy: Dict[str, Any] = {}

    def _absorb(source: Dict[str, Any], names: Tuple[str, ...]) -> None:
        # Merge every dict-valued (or JSON-object-valued) entry into the policy.
        for name in names:
            section = _coerce_json(source.get(name))
            if isinstance(section, dict):
                policy.update(section)

    # variant-level nested
    _absorb(variant, ("face_match", "face_policy", "selfie", "selfie_policy"))

    # validation_logic-level nested (merged last, so it wins)
    if isinstance(validation_logic, dict):
        _absorb(validation_logic, ("__face__", "face_match", "_face", "selfie"))

    # flat fallback
    has_switch = any(name in policy for name in ("enabled", "mode", "required"))
    if "require_face_match" in variant and not has_switch:
        policy["enabled"] = variant.get("require_face_match")

    for flat_name, target in (("face_match_threshold", "threshold"), ("face_metric", "metric")):
        if flat_name in variant and target not in policy:
            policy[target] = variant.get(flat_name)

    return policy


def build_doc_config_payload_from_variant(
    *,
    variant: Dict[str, Any],
    class_id_to_name: Dict[int, str],
    doc_type: str,
    country: str,
    defaults: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """
    Translate one policy-server variant into a DocConfig payload dict.

    Args:
      variant: raw rule row; "required_steps" and "validation_logic" may be
        native JSON values or JSON-encoded strings.
      class_id_to_name: YOLO class id -> class name; ids without a name are skipped.
      doc_type / country: used for doc_id / variant_name and passed to the
        identity registry.
      defaults: optional overrides for thresholds and global settings.

    Returns:
      Dict with one entry in "rules" per known class id, the approval /
      review thresholds, and the face-match policy.

    Notes:
      - A class id is "required" if it appears in required_steps or its logic
        sets required / require / must_exist.
      - Per-field logic is found even when its key is a non-canonical
        spelling of the id (e.g. "03" or " 3").
      - match_gate / must_match accept loosely-typed flags ("false", "0", ...)
        via `_to_bool`; uninterpretable values keep the per-key default.
      - An explicit falsy face flag (enabled / required / require_face_match)
        overrides defaults["require_face_match"].
    """
    defaults = defaults or {}
    log = logging.getLogger(__name__)

    # External, country-level identity format/validator defaults (JSON file).
    # This keeps checksum/format rules out of the DB variants and makes them centrally maintainable.
    identity_rules_path = str(defaults.get("identity_rules_path") or os.environ.get("KYC_IDENTITY_RULES_PATH") or "")
    strict_identity = _to_bool(defaults.get("strict_identity_registry") or os.environ.get("KYC_STRICT_IDENTITY_REGISTRY")) or False

    required_steps = _coerce_json(variant.get("required_steps")) or []
    if isinstance(required_steps, str):
        required_steps = _coerce_json(required_steps) or []
    if not isinstance(required_steps, list):
        required_steps = []
    required_ids = {int(x) for x in required_steps if str(x).isdigit()}

    validation_logic = _coerce_json(variant.get("validation_logic")) or {}
    if isinstance(validation_logic, str):
        validation_logic = _coerce_json(validation_logic) or {}
    if not isinstance(validation_logic, dict):
        validation_logic = {}

    # Index per-class logic by integer id. Non-numeric keys (e.g. "__face__")
    # are skipped here; they are read by the face-policy extraction below.
    logic_by_id: Dict[int, Dict[str, Any]] = {}
    ids = set()
    for raw_id, raw_logic in validation_logic.items():
        try:
            parsed_id = int(raw_id)
        except (TypeError, ValueError, OverflowError):
            continue
        ids.add(parsed_id)
        if not isinstance(raw_logic, dict):
            continue
        logic_by_id.setdefault(parsed_id, raw_logic)
        # allow marking required inside validation_logic too
        if _to_bool(raw_logic.get("required") or raw_logic.get("require") or raw_logic.get("must_exist")):
            required_ids.add(parsed_id)
    ids |= required_ids

    rules: List[Dict[str, Any]] = []

    for cid in sorted(ids):
        class_name = class_id_to_name.get(cid)
        if not class_name:
            continue

        key = _CANON_KEYS.get(class_name, class_name.strip().lower().replace(" ", "_"))

        # Exact key spellings first; the normalised index catches "03", " 3", ...
        logic = validation_logic.get(str(cid)) or validation_logic.get(cid) or logic_by_id.get(cid) or {}
        if not isinstance(logic, dict):
            logic = {}

        # Merge country-level defaults for this field (if present).
        # Variant-level logic wins over defaults.
        try:
            identity_defaults = get_identity_defaults(
                country=country,
                doc_type=doc_type,
                field_key=key,
                path=identity_rules_path,
                strict=bool(strict_identity),
            )
        except Exception as e:
            log.warning("identity_defaults_load_failed country=%s doc_type=%s key=%s err=%s", country, doc_type, key, e)
            identity_defaults = None

        if isinstance(identity_defaults, dict) and identity_defaults:
            logic = {**identity_defaults, **logic}

        ocr_kind = _infer_ocr_kind(class_name, logic)

        match_type, match_gate, must_match, weight = _default_match_for_key(key)

        match_type = str(logic.get("match_type", match_type))
        # bool("false") is True, so flags are parsed instead of cast.
        gate_override = _first_bool(logic, ("match_gate",))
        if gate_override is not None:
            match_gate = gate_override
        must_override = _first_bool(logic, ("must_match",))
        if must_override is not None:
            must_match = must_override
        weight = float(logic.get("weight", weight))

        min_det_conf = float(logic.get("min_det_conf", defaults.get("min_det_conf", 0.35 if cid in required_ids else 0.25)))
        # In LLM-OCR version, this is LLM confidence threshold
        min_ocr_conf = float(logic.get("min_ocr_conf", defaults.get("min_ocr_conf", 0.20)))

        constraints = _constraints_from_logic(logic)

        expected_len = int(
            logic.get("expected_len")
            or logic.get("length")
            or logic.get("len")
            or (constraints.get("length") if isinstance(constraints, dict) else None)
            or defaults.get("expected_len", 0)
            or 0
        )

        validator = logic.get("validator")

        if country.upper() == "IR" and key == "id_number":
            validator = validator or "iran_national_code"
            expected_len = expected_len or 10

        max_candidates = int(logic.get("max_candidates", defaults.get("max_candidates", 2 if key in _CORE_KEYS else 1)) or 1)

        input_aliases = logic.get("input_aliases")
        if isinstance(input_aliases, str):
            input_aliases = [input_aliases]
        if input_aliases is not None and not isinstance(input_aliases, list):
            input_aliases = None

        rules.append({
            "class_names": [class_name],
            "key": key,
            "required": bool(cid in required_ids),

            "min_det_conf": min_det_conf,
            "min_ocr_conf": min_ocr_conf,

            "ocr_kind": ocr_kind,

            "match_type": match_type,
            "match_threshold": float(logic.get("match_threshold", defaults.get("match_threshold", 0.85))),
            "expected_len": expected_len,

            "validator": validator,
            "constraints": constraints,

            "weight": weight,
            "must_match": must_match,
            "match_gate": match_gate,

            "max_candidates": max_candidates,
            "input_aliases": input_aliases,
        })

    variant_id = variant.get("id") or variant.get("rule_id") or variant.get("variant_id")
    variant_name = variant.get("variant_name") or f"{country.lower()}_{doc_type.lower()}_{variant_id}"

    min_count = defaults.get("min_detected_fields_count")
    if min_count is None:
        min_count = len(required_ids) if required_ids else None

    # ----------- Face policy (ON/OFF) -----------
    face_policy = _extract_face_policy_on_off(variant, validation_logic)

    mode = str(face_policy.get("mode") or "").strip().lower()
    # First interpretable flag wins, so an explicit False / 0 is honoured
    # rather than being skipped as "missing".
    enabled_flag = _first_bool(face_policy, ("enabled", "required", "require_face_match"))

    require_face_match = bool(defaults.get("require_face_match", False))

    # mode mapping: required/on/enabled -> True ; disabled/off/none/optional -> False
    if mode in {"required", "require", "mandatory", "on", "enabled"}:
        require_face_match = True
    elif mode in {"disabled", "none", "off", "optional"}:
        require_face_match = False
    elif enabled_flag is not None:
        require_face_match = bool(enabled_flag)

    face_metric = str(face_policy.get("metric") or defaults.get("face_metric", "score01"))
    face_match_threshold = float(face_policy.get("threshold", defaults.get("face_match_threshold", 0.75)))

    payload = {
        "doc_id": f"{country.lower()}_{doc_type.lower()}_{variant_name}",
        "variant_id": variant_id,
        "variant_name": variant_name,
        "country": country,
        "doc_type": doc_type,
        "ocr_locale_hint": _infer_ocr_locale_hint(country, defaults),

        "rules": rules,
        "min_detected_fields_count": min_count,

        "approve_min_coverage": float(defaults.get("approve_min_coverage", 1.0)),
        "approve_min_extraction": float(defaults.get("approve_min_extraction", 0.78)),
        "approve_min_match_core": float(defaults.get("approve_min_match_core", 0.92)),
        "approve_min_match_all": float(defaults.get("approve_min_match_all", 0.0)),

        "review_min_coverage": float(defaults.get("review_min_coverage", 0.75)),
        "reject_below_coverage": float(defaults.get("reject_below_coverage", 0.50)),

        "approve_no_input_extra_buffer": float(defaults.get("approve_no_input_extra_buffer", 0.10)),

        "require_face_match": bool(require_face_match),
        "face_metric": str(face_metric),
        "face_match_threshold": float(face_match_threshold),

        "swap_pairs": defaults.get("swap_pairs", []),
        "enable_name_swap": bool(defaults.get("enable_name_swap", True)),
        "name_swap_margin": float(defaults.get("name_swap_margin", 0.06)),
    }
    return payload


def _first_bool(source: Dict[str, Any], names: Tuple[str, ...]) -> Optional[bool]:
    """Return the first value among ``names`` in ``source`` that `_to_bool` can interpret, else None."""
    for name in names:
        parsed = _to_bool(source.get(name))
        if parsed is not None:
            return parsed
    return None


def _merge_union_payload(payloads: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Build the "union" payload used to run ONE extraction (YOLO crops + LLM OCR)
    covering the fields of every variant.

    Rules are merged by "key":
      - the first rule seen for a key is copied (shallow) with "required" forced to False
      - min_det_conf / min_ocr_conf: lowest value across variants (to not miss)
      - max_candidates: highest value across variants
      - everything else (class_names, ocr_kind, constraints, match settings)
        stays as in the first rule; variants are assumed to agree on those

    Country, doc_type and locale hint come from the first payload; an empty
    input yields an empty rule list.
    """
    merged: Dict[str, Dict[str, Any]] = {}

    for payload in payloads:
        for rule in payload.get("rules") or []:
            if not isinstance(rule, dict):
                continue
            key = str(rule.get("key") or "")
            if not key:
                continue

            existing = merged.get(key)
            if existing is None:
                seed = dict(rule)
                seed["required"] = False  # union extraction doesn't need "required"
                merged[key] = seed
                continue

            # Loosen the thresholds so no variant's field is filtered out.
            for name, fallback in (("min_det_conf", 0.25), ("min_ocr_conf", 0.20)):
                existing[name] = float(
                    min(float(existing.get(name, fallback)), float(rule.get(name, fallback)))
                )
            existing["max_candidates"] = int(
                max(int(existing.get("max_candidates", 1) or 1), int(rule.get("max_candidates", 1) or 1))
            )

    # Provide a stable doc_id for union; real doc_id is per variant in scoring anyway.
    first = payloads[0] if payloads else {}
    return {
        "doc_id": "union_extraction",
        "country": str(first.get("country") or ""),
        "doc_type": str(first.get("doc_type") or ""),
        "ocr_locale_hint": str(first.get("ocr_locale_hint") or first.get("locale_hint") or ""),
        "rules": list(merged.values()),
        # extraction scoring is per variant, so thresholds don't matter here
        "min_detected_fields_count": None,
        "require_face_match": False,
    }


@dataclass
class VariantCandidate:
    """Summary of how one policy variant scored; built per variant in run_kyc_auto_variant."""

    # "variant_name" of the variant's DocConfig payload (stringified).
    variant_name: str
    # "doc_id" of the variant's DocConfig payload (stringified).
    payload_doc_id: str
    # Decision string returned by the engine's scoring for this variant.
    decision: str
    # Score dict returned by the engine's scoring for this variant.
    scores: Dict[str, Any]
    # Reason strings returned by the engine's scoring (empty list if none).
    reasons: List[str]
    # Ranking tuple from _candidate_sort_key; larger tuples rank higher.
    sort_key: Tuple


def _decision_rank(d: str) -> int:
    """
    Rank a decision label for sorting (case-insensitive).

    APPROVE -> 3, REVIEW -> 2, REJECT -> 1; anything else, including
    None or an empty string, -> 0.
    """
    ranks = {"APPROVE": 3, "REVIEW": 2, "REJECT": 1}
    return ranks.get((d or "").upper(), 0)


def _candidate_sort_key(decision: str, scores: Dict[str, Any]) -> Tuple:
    """
    Build the tuple used to rank variant candidates (larger is better).

    Order of significance: decision rank, final_score, match_core, match_all,
    extraction, coverage, face. Missing or None values for final_score,
    match_core, match_all and face count as -1.0; extraction and coverage
    default to 0.0 when the key is absent.
    """
    table = scores or {}

    def _or_floor(name: str) -> float:
        # None / missing sorts below every real score.
        value = table.get(name)
        return float(-1.0 if value is None else value)

    return (
        _decision_rank(decision),
        _or_floor("final_score"),
        _or_floor("match_core"),
        _or_floor("match_all"),
        float(table.get("extraction", 0.0)),
        float(table.get("coverage", 0.0)),
        _or_floor("face"),
    )


def run_kyc_auto_variant(
    *,
    kyc_engine_module,
    php_endpoint_url: str,
    model_path: str,
    doc_type: str,
    country: str,
    doc_image_path: str,

    # LLM OCR hook (REQUIRED)
    llm_ocr: Callable[[Dict[str, Any]], Dict[str, Any]],

    selfie_image_path: Optional[str] = None,
    user_input: Optional[Dict[str, str]] = None,
    debug: bool = False,
    defaults: Optional[Dict[str, Any]] = None,
    device: str = "0",
    yolo_conf: float = 0.25,
    yolo_iou: float = 0.6,
    max_det: int = 200,
) -> Dict[str, Any]:
    """
    Fetch the policy variants for (doc_type, country), run detection and LLM
    OCR once on the union of their fields, score every variant, and return
    the best one.

    Steps:
      1) fetch the policy JSON from ``php_endpoint_url``
      2) require status == "success"
      3) resolve the YOLO model path (the server value overrides ``model_path``)
      4) load the engine via ``kyc_engine_module.get_engine``
      5-6) build one DocConfig payload/config per variant
      7-8) detect once, extract once with the union config
      9) compute the face pack once if any variant requires face match
      10) score each variant and keep the highest `_candidate_sort_key`

    Returns:
      Dict with selected_variant_doc_id, selected_variant, decision, scores,
      reasons, per_field, selfie_mode and next_steps. Whenever no variant can
      be evaluated, a REVIEW result carrying a single reason is returned
      instead. With ``debug=True`` extra diagnostic keys are attached.

    Raises:
      RuntimeError: propagated from `_http_get_json` when the policy cannot
        be fetched. Errors raised by ``kyc_engine_module`` are not caught here.
    """
    defaults = defaults or {}
    user_input = user_input or {}

    # 1) Fetch policy FIRST
    payload = _http_get_json(
        php_endpoint_url,
        params={"type": doc_type, "country": country},
        timeout_s=float(defaults.get("policy_timeout_s", 12.0)),
        retries=int(defaults.get("policy_retries", 2)),
    )

    # 2) Must be success (valid JSON that is not an object is also a failure)
    if not isinstance(payload, dict):
        return _review_outcome(
            f"policy_fetch_failed:unexpected_payload_type:{type(payload).__name__}", payload, debug
        )
    if str(payload.get("status", "")).lower() != "success":
        return _review_outcome(f"policy_fetch_failed:{payload.get('message')}", payload, debug)

    # 3) Resolve YOLO model path from server config
    server_model_path = _extract_yolo_model_path(payload)
    if server_model_path:
        model_path = server_model_path  # override cli/fallback

    if not model_path:
        return _review_outcome("yolo_model_path_missing_in_policy", payload, debug)

    if debug:
        print("[KYC] yolo_model_path(from server):", server_model_path)
        print("[KYC] yolo_model_path(used):", model_path)

    # 4) NOW load YOLO engine with correct model
    engine = kyc_engine_module.get_engine(model_path)
    class_id_to_name = dict(engine.class_names)

    # 5) Get variants list
    variants = _find_rules_list(payload)
    if not variants:
        return _review_outcome("no_variants_found", payload, debug)

    # 6) Build configs for each variant (the three lists stay index-aligned)
    variant_payloads: List[Dict[str, Any]] = []
    variant_cfgs: List[Any] = []
    raw_variants: List[Dict[str, Any]] = []

    for v in variants:
        if not isinstance(v, dict):
            continue

        doc_cfg_payload = build_doc_config_payload_from_variant(
            variant=v,
            class_id_to_name=class_id_to_name,
            doc_type=doc_type,
            country=country,
            defaults=defaults,
        )
        variant_payloads.append(doc_cfg_payload)
        variant_cfgs.append(kyc_engine_module.doc_config_from_payload(doc_cfg_payload))
        raw_variants.append(v)

    if not variant_payloads:
        return _review_outcome("no_variant_executed", payload, debug)

    # 7) Detect ONCE
    dets = engine.detect(doc_image_path, conf=yolo_conf, iou=yolo_iou, device=device, max_det=max_det)

    # 8) Extract ONCE (union fields) via LLM OCR
    union_payload = _merge_union_payload(variant_payloads)
    union_cfg = kyc_engine_module.doc_config_from_payload(union_payload)

    pack = engine.extract_with_llm(
        doc_image_path,
        union_cfg,
        dets,
        llm_ocr=llm_ocr,
        user_input=None,  # do NOT bias OCR with user input
        debug=debug
    )

    base_fields = pack["fields"]
    base_internals = pack["internals"]

    # 9) Face pack computed once if ANY variant requires it
    face_pack_global = None
    any_face_required = any(getattr(cfg, "require_face_match", False) for cfg in variant_cfgs)
    if any_face_required:
        if selfie_image_path:
            quads = base_internals.get("quads", {})
            quad = quads.get("doc_photo") or quads.get("photo")
            face_pack_global = kyc_engine_module.compute_face_pack(
                doc_image_path=doc_image_path,
                selfie_image_path=selfie_image_path,
                doc_photo_quad=quad,
            )
        else:
            face_pack_global = {"score01": None, "cosine": None, "reason": "selfie_missing", "details": None}

    # 10) Score each variant
    best_key: Optional[Tuple] = None
    best_result: Optional[Dict[str, Any]] = None
    best_variant_raw: Optional[Dict[str, Any]] = None
    cand_list: List[VariantCandidate] = []

    for cfg, doc_cfg_payload, vraw in zip(variant_cfgs, variant_payloads, raw_variants):
        # Swaps mutate fields/internals, so every variant works on its own copy.
        fields = copy.deepcopy(base_fields)
        internals = copy.deepcopy(base_internals)

        kyc_engine_module.apply_config_swaps(fields, internals, cfg)
        kyc_engine_module.apply_name_swap_if_needed(fields, internals, cfg, user_input)

        # Same tolerant lookup as in step 9.
        needs_face = bool(getattr(cfg, "require_face_match", False))
        face_pack = face_pack_global if needs_face else None

        scoring = engine.score(cfg, fields, user_input=user_input, face_pack=face_pack, debug=debug)
        decision = str(scoring.get("decision") or "")
        scores = scoring.get("scores") or {}
        sk = _candidate_sort_key(decision, scores)

        res = {
            "doc_id": str(doc_cfg_payload.get("doc_id")),
            "decision": decision,
            "scores": scores,
            "reasons": scoring.get("reasons") or [],
            "per_field": scoring.get("per_field") or {},
            "selfie_mode": ("required" if needs_face else "disabled"),
            "next_steps": (["upload_selfie_required"] if (needs_face and not selfie_image_path) else []),
        }

        cand_list.append(VariantCandidate(
            variant_name=str(doc_cfg_payload.get("variant_name")),
            payload_doc_id=str(doc_cfg_payload.get("doc_id")),
            decision=decision,
            scores=scores,
            reasons=res.get("reasons") or [],
            sort_key=sk,
        ))

        # Strict ">" keeps the earliest variant on ties.
        if best_key is None or sk > best_key:
            best_key = sk
            best_result = res
            best_variant_raw = vraw

    if best_result is None:
        return _review_outcome("no_variant_executed", payload, debug)

    out = {
        "selected_variant_doc_id": str(best_result.get("doc_id")),
        "selected_variant": str(best_result.get("doc_id")),
        "decision": best_result.get("decision"),
        "scores": best_result.get("scores"),
        "reasons": best_result.get("reasons"),
        "per_field": best_result.get("per_field"),
        "selfie_mode": best_result.get("selfie_mode"),
        "next_steps": best_result.get("next_steps"),
    }

    if debug:
        out["variant_candidates"] = [c.__dict__ for c in sorted(cand_list, key=lambda x: x.sort_key, reverse=True)]
        out["best_variant_raw_rule"] = best_variant_raw
        out["raw_policy"] = payload
        out["detections"] = dets
        out["fields"] = base_fields
        out["internals"] = base_internals
        out["face"] = face_pack_global
        out["llm_bundle"] = pack.get("llm_bundle")
        out["llm_response_raw"] = pack.get("llm_response_raw")

    return out


def _review_outcome(reason: str, raw_policy: Any, debug: bool) -> Dict[str, Any]:
    """Build the REVIEW result returned when no variant could be evaluated; attaches the raw policy in debug mode."""
    out: Dict[str, Any] = {
        "selected_variant_doc_id": None,
        "selected_variant": None,
        "decision": "REVIEW",
        "scores": {},
        "reasons": [reason],
        "per_field": {},
    }
    if debug:
        out["raw_policy"] = raw_policy
    return out
