# identity_registry.py
# ------------------------------------------------------------
# Central, country-level identity format + checksum rules.
#
# Why this exists:
# - Your document variants (DB/CSV) describe *layout* (YOLO classes, required fields, match rules).
# - Identity numbers (national ID, passport numbers, etc.) have deterministic format/checksum rules
#   that should NOT be embedded in LLM prompts and should NOT be duplicated per-variant.
# - This registry lets you add/adjust rules per country in one place.

from __future__ import annotations

import json
import os
import threading
from copy import deepcopy
from typing import Any, Dict, Optional

# Whitelist of spec keys that may flow into the validation_logic merge.
# _sanitize_spec drops every key that is not listed here, which keeps the
# merged surface area small & safe.
_ALLOWED_KEYS = {
    # --- OCR kind ---
    "type",
    "ocr_kind",
    # --- core validator selector ---
    "validator",
    # --- common constraints (policy_provider extracts these) ---
    "length",
    "len",
    "min_length",
    "min_len",
    "max_length",
    "max_len",
    "prefix",
    "regex",
    "must_parse",
    # --- checksum params (passed through into constraints for validator functions) ---
    "weights",
    "weight_list",
    "mod",
    "check_index",
    "check_rule",
    # --- misc optional ---
    "expected_len",
}

# Field keys that are checked in strict mode when the caller passes no
# strict_keys of its own: a missing rule for one of these produces a failing
# validator instead of "no defaults".
_STRICT_DEFAULT_KEYS = {
    "id_number",
    "passport_no",
    "license_number",
    "driver_license_no",
}

# Single-slot cache for the registry file loaded most recently.
# load_identity_registry reads and updates it while holding _CACHE_LOCK.
#   "path"  -> path the cached data was loaded from
#   "mtime" -> file mtime observed for that load (None when unavailable)
#   "data"  -> parsed dict
_CACHE_LOCK = threading.Lock()
_CACHE: Dict[str, Any] = dict.fromkeys(("path", "mtime", "data"))


def _pick_default_path() -> str:
    """
    Choose the registry file path to use when the caller did not supply one.

    Resolution order:
      1. The KYC_IDENTITY_RULES_PATH environment variable, when it holds a
         non-blank value. Surrounding whitespace is stripped, the same way
         load_identity_registry normalises an explicitly passed path, so a
         blank value no longer turns into a bogus path.
      2. The first existing file among identity_rules.json,
         identity_registry.json and country_identity_rules.json located in
         this module's directory.
      3. identity_rules.json in this module's directory, even if it doesn't
         exist yet (caller may create it later).
    """
    env = (os.environ.get("KYC_IDENTITY_RULES_PATH") or "").strip()
    if env:
        return env

    here = os.path.dirname(os.path.abspath(__file__))
    candidates = [
        os.path.join(here, name)
        for name in (
            "identity_rules.json",
            "identity_registry.json",
            "country_identity_rules.json",
        )
    ]
    # First candidate that exists wins; otherwise fall back to the primary name.
    return next((p for p in candidates if os.path.exists(p)), candidates[0])


def _file_mtime(path: str) -> Optional[float]:
    """
    Return the modification time of *path* as reported by os.path.getmtime,
    or None when it cannot be obtained.

    Only the failures getmtime raises for a string path are treated as
    "no mtime": OSError (missing or inaccessible file) and ValueError
    (e.g. an embedded NUL byte). Anything else indicates a programming error
    and is allowed to propagate instead of being silently swallowed.
    """
    try:
        return os.path.getmtime(path)
    except (OSError, ValueError):
        return None


def load_identity_registry(path: Optional[str] = None) -> Dict[str, Any]:
    """
    Loads + caches the identity registry JSON. Auto-reloads on file mtime change.

    Schema (suggested):
    {
      "schema_version": 1,
      "countries": {
        "IR": {
          "kinds": {
            "national_id": { "type":"numeric", "length":10, "validator":"iran_national_code", "regex":"^[0-9]{10}$" }
          },
          "doc_types": {
            "id_card": { "fields": { "id_number": { "use": "national_id" } } },
            "driver_license": { "fields": { "id_number": { "use": "national_id" } } }
          }
        }
      }
    }

    Args:
        path: Registry file to load. Blank/None falls back to _pick_default_path().

    Returns:
        A dict that always has a dict under "countries". The very same object
        is kept in the module cache and handed to later callers, so treat it
        as read-only.
        - Missing file: an empty registry ({"schema_version": 1, "countries": {}}).
        - Flat map ({"IR": {...}, ...}) without a "countries" dict: wrapped as
          the countries map.
        - Any other object without a "countries" dict: an empty registry that
          keeps the document's schema_version when it is usable.

    Raises:
        RuntimeError: the file is not valid JSON, or its top level is not an object.
        OSError / UnicodeDecodeError: the file exists but cannot be read or decoded.
    """
    p = (path or "").strip() or _pick_default_path()
    mtime = _file_mtime(p)

    with _CACHE_LOCK:
        # Cache hit: same file, same mtime as the last successful load.
        if _CACHE["data"] is not None and _CACHE["path"] == p and _CACHE["mtime"] == mtime:
            return _CACHE["data"]

        raw: Optional[str] = None
        if os.path.exists(p):
            try:
                with open(p, "r", encoding="utf-8-sig") as f:
                    raw = f.read()
            except FileNotFoundError:
                # The file vanished between the exists() check and open();
                # treat it exactly like a file that was never there.
                raw = None

        if raw is None:
            data = {"schema_version": 1, "countries": {}}
            _CACHE.update({"path": p, "mtime": mtime, "data": data})
            return data

        try:
            data = json.loads(raw)
        except Exception as e:
            raise RuntimeError(f"Identity registry JSON parse failed: path={p} err={e}") from e

        if not isinstance(data, dict):
            raise RuntimeError(f"Identity registry must be a JSON object: path={p}")

        if "countries" not in data or not isinstance(data.get("countries"), dict):
            # tolerate a flat map: { "IR": {...} }
            if all(isinstance(k, str) and isinstance(v, dict) for k, v in data.items()):
                data = {"schema_version": 1, "countries": data}
            else:
                # This branch exists to tolerate malformed documents, so a
                # malformed schema_version must not crash it either.
                try:
                    version = int(data.get("schema_version") or 1)
                except (TypeError, ValueError):
                    version = 1
                data = {"schema_version": version, "countries": {}}

        _CACHE.update({"path": p, "mtime": mtime, "data": data})
        return data


def _sanitize_spec(d: Dict[str, Any]) -> Dict[str, Any]:
    """
    Build a new dict containing only the entries of *d* whose key appears in
    _ALLOWED_KEYS. A falsy *d* (None, empty dict) yields an empty dict.
    Values are carried over as-is (not copied).
    """
    return {key: value for key, value in (d or {}).items() if key in _ALLOWED_KEYS}


def _resolve_field_spec(country_entry: Dict[str, Any], doc_type: str, field_key: str) -> Optional[Dict[str, Any]]:
    """
    Look up the rule for (doc_type, field_key) inside one country's entry.

    Lookup details:
      - doc_type is tried trimmed + lower-cased first, then exactly as given.
      - Fields live under doc_types[<doc_type>]["fields"]; when that key is
        absent the doc-type entry itself is used as the field map (shorthand).
      - field_key is trimmed before lookup.
      - A string rule "<kind>" is shorthand for {"use": "<kind>"}.
      - {"use": "<kind>"} is resolved through country_entry["kinds"]; a kind
        whose value is itself a string is followed once as an alias. Keys
        written next to "use" override the kind's values.

    Returns:
        The merged dict (kind + overrides, without "use") when the reference
        resolves; the rule dict unchanged when it has no "use" or the
        reference cannot be resolved; None when no usable rule is found.
    """
    wanted_doc = (doc_type or "").strip().lower()
    wanted_field = (field_key or "").strip()

    doc_table = country_entry.get("doc_types")
    doc_table = doc_table if isinstance(doc_table, dict) else {}

    doc_entry = doc_table.get(wanted_doc) or doc_table.get(doc_type) or {}
    doc_entry = doc_entry if isinstance(doc_entry, dict) else {}

    field_table = doc_entry.get("fields")
    if field_table is None:
        # allow shorthand: doc_types[id_card][id_number] = {...}
        field_table = doc_entry
    if not isinstance(field_table, dict):
        return None

    rule = field_table.get(wanted_field)
    if isinstance(rule, str):
        # A bare string names the kind to use.
        rule = {"use": rule}
    if not isinstance(rule, dict):
        # Covers both "no rule for this field" (None) and unsupported shapes.
        return None

    kind_name = rule.get("use")
    if not kind_name:
        return rule

    kind_table = country_entry.get("kinds") or {}
    if not isinstance(kind_table, dict):
        return rule

    template = kind_table.get(str(kind_name))
    if isinstance(template, str):
        # One level of aliasing: kinds["a"] = "b" points at kinds["b"].
        template = kind_table.get(template)
    if not isinstance(template, dict):
        return rule

    # Kind values first, then everything written next to "use" wins.
    overrides = {key: value for key, value in rule.items() if key != "use"}
    return {**template, **overrides}


def get_identity_defaults(
    *,
    country: str,
    doc_type: str,
    field_key: str,
    path: str = "",
    strict: bool = False,
    strict_keys: Optional[set] = None,
) -> Optional[Dict[str, Any]]:
    """
    Returns a dict of defaults to be merged into validation_logic for a given (country, doc_type, field_key).

    - If no defaults exist: returns None (unless strict=True and field_key is in strict_keys,
      in which case {"validator": "missing_identity_rule"} is returned).
    - If defaults exist: returns a sanitized, independent copy containing only allowed keys
      (this can be an empty dict when the rule carries no allowed key).

    Args:
        country: Country code; matched trimmed + upper-cased, then lower-cased.
        doc_type: Document type, resolved by _resolve_field_spec.
        field_key: Field name. It is trimmed for the strict-mode membership test
            just like _resolve_field_spec trims it for lookup, so a padded key
            cannot slip past strict mode.
        path: Registry file; blank means the default location.
        strict: When True, a missing rule for a strict key yields a failing validator.
        strict_keys: Keys subject to strict mode; None or an empty set falls back
            to _STRICT_DEFAULT_KEYS.

    Raises:
        Whatever load_identity_registry raises for an unreadable/invalid registry.
    """
    data = load_identity_registry(path)
    countries = data.get("countries") if isinstance(data, dict) else None
    if not isinstance(countries, dict):
        countries = {}

    # Normalise the key the same way the lookup does before testing strictness.
    field_name = field_key.strip() if isinstance(field_key, str) else (field_key or "")
    if strict and field_name in (strict_keys or _STRICT_DEFAULT_KEYS):
        when_missing: Optional[Dict[str, Any]] = {"validator": "missing_identity_rule"}
    else:
        when_missing = None

    cc = (country or "").strip().upper()
    entry = countries.get(cc) or countries.get(cc.lower())
    if not isinstance(entry, dict):
        return when_missing

    spec = _resolve_field_spec(entry, doc_type, field_key)
    if not isinstance(spec, dict):
        return when_missing

    # Filter first so only the allowed subset is deep-copied; the caller gets a
    # dict that shares nothing with the cached registry.
    return deepcopy(_sanitize_spec(spec))
