"""
MASHBiospecimenCoC-Kor — 순수 로직 모듈
========================================
Streamlit 컨텍스트 없이 import/호출 가능한 순수 함수 모음.
window 준수, custody 사슬 완전성, shipment QC, pathology turnaround, block 재고 계산.

본 도구는 연구·운영 보조용 참고 도구이며 실제 임상시험 규제 의사결정을 대체하지 않는다.
"""

from __future__ import annotations

import os
from datetime import datetime

import pandas as pd

# Default location of the study CSV files: a "data" directory that sits next
# to this module file. load_csv()/load_all() fall back to it when no
# data_dir is given.
DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")

# Standard chain-of-custody steps a physical specimen is expected to pass
# through (IATA/GCP chain-of-custody concept). custody_completeness() uses
# these, as a set, as the required steps for every non-VCTE specimen.
EXPECTED_CUSTODY_STEPS = [
    "collection",
    "site_storage",
    "shipment_dispatch",
    "central_receipt",
]


# ---------------------------------------------------------------------------
# 데이터 로드
# ---------------------------------------------------------------------------
def load_csv(name: str, data_dir: str | None = None) -> pd.DataFrame:
    data_dir = data_dir or DATA_DIR
    return pd.read_csv(os.path.join(data_dir, name))


def load_all(data_dir: str | None = None) -> dict[str, pd.DataFrame]:
    """Load every study CSV and return them keyed by short dataset name.

    Keys: visit_schedule, manifest, custody, shipment, pathology. Each file is
    read through load_csv() with the same *data_dir*.
    """
    sources = (
        ("visit_schedule", "visit_schedule.csv"),
        ("manifest", "specimen_manifest.csv"),
        ("custody", "custody_log.csv"),
        ("shipment", "shipment_log.csv"),
        ("pathology", "pathology_turnaround.csv"),
    )
    return {key: load_csv(filename, data_dir) for key, filename in sources}


# ---------------------------------------------------------------------------
# 기능 1: schedule window 준수
# ---------------------------------------------------------------------------
def _parse_window_date(value):
    """Convert a date-like value to a midnight ``datetime``; None when missing.

    Accepts ``YYYY-MM-DD`` strings (surrounding whitespace ignored) and
    ``datetime`` instances (including ``pandas.Timestamp``), whose time part
    is dropped. None, blank strings and NaN/NaT are treated as missing.

    Raises:
        ValueError: if a non-missing value does not match ``YYYY-MM-DD``.
    """
    if value is None:
        return None
    if isinstance(value, str):
        text = value.strip()
        if not text:
            return None
        return datetime.strptime(text, "%Y-%m-%d")
    # NaN/NaT check must precede the datetime branch: pd.NaT is a datetime.
    if pd.isna(value):
        return None
    if isinstance(value, datetime):
        # str(datetime) includes a time part that "%Y-%m-%d" would reject.
        return datetime(value.year, value.month, value.day)
    return datetime.strptime(str(value), "%Y-%m-%d")


def is_in_window(collected_date: str, window_start: str, window_end: str) -> bool:
    """Return True if *collected_date* lies in the closed window [window_start, window_end].

    All three values are ``YYYY-MM-DD`` strings; ``datetime``/``Timestamp``
    objects are also accepted and compared by calendar date only.

    Returns False when any of the three dates is missing (None, empty string,
    NaN/NaT), because window compliance cannot be established without all of
    them.

    Raises:
        ValueError: if a non-missing value is not in ``YYYY-MM-DD`` format.
    """
    collected = _parse_window_date(collected_date)
    if collected is None:
        return False
    start = _parse_window_date(window_start)
    end = _parse_window_date(window_end)
    if start is None or end is None:
        return False
    return start <= collected <= end


def schedule_compliance_summary(manifest: pd.DataFrame) -> dict:
    """Summarise schedule compliance for the specimen manifest.

    "Due" rows are those whose status is ``collected`` or ``missed``. The
    collection rate is collected / due (0.0 when nothing is due). Window
    violations are counted over the whole manifest (``window_violation == "yes"``).
    """
    status = manifest["status"]
    due_status = status[status.isin(["collected", "missed"])]
    due_count = len(due_status)
    collected_count = int(due_status.eq("collected").sum())
    missed_count = int(due_status.eq("missed").sum())
    violation_count = int(manifest["window_violation"].eq("yes").sum())

    if due_count:
        rate = round(collected_count / due_count, 4)
    else:
        rate = 0.0

    return {
        "total_rows": len(manifest),
        "due_rows": due_count,
        "collected": collected_count,
        "missed": missed_count,
        "window_violations": violation_count,
        "collection_rate": rate,
    }


def window_alert_list(manifest: pd.DataFrame) -> pd.DataFrame:
    """Return manifest rows needing attention: missed, or collected out of window.

    The result keeps the original index and is sorted by site, patient, visit.
    """
    wanted = [
        "patient_id", "site_id", "visit", "specimen_type",
        "target_due_date", "window_start", "window_end",
        "collected_date", "status", "window_violation",
    ]
    flagged = manifest["status"].eq("missed") | manifest["window_violation"].eq("yes")
    alerts = manifest.loc[flagged, wanted]
    return alerts.sort_values(by=["site_id", "patient_id", "visit"])


# ---------------------------------------------------------------------------
# 기능 2: chain-of-custody 완전성
# ---------------------------------------------------------------------------
def custody_completeness(custody: pd.DataFrame) -> pd.DataFrame:
    """Report, per specimen UID, which expected custody steps are missing.

    VCTE specimens are expected to have only a ``collection`` step (treated as
    complete with that alone); every other specimen type must have all steps in
    EXPECTED_CUSTODY_STEPS. ``specimen_type``, ``patient_id`` and ``site_id``
    are taken from the first log row of each specimen.

    Returns:
        One row per ``specimen_uid`` (in groupby order, i.e. sorted by uid) with
        columns specimen_uid, patient_id, site_id, specimen_type, n_steps
        (number of log rows, duplicates included), missing_steps
        ("|"-joined, alphabetically sorted; "" when none) and chain_complete
        ("yes"/"no"). An empty log yields an empty frame that still carries
        these columns, so callers can filter on them unconditionally.
    """
    columns = [
        "specimen_uid", "patient_id", "site_id", "specimen_type",
        "n_steps", "missing_steps", "chain_complete",
    ]
    rows = []
    for uid, grp in custody.groupby("specimen_uid"):
        spec = grp["specimen_type"].iloc[0]
        steps = set(grp["step"])
        # VCTE is a non-physical measurement: only the collection event exists.
        if spec == "VCTE":
            expected = {"collection"}
        else:
            expected = set(EXPECTED_CUSTODY_STEPS)
        missing = sorted(expected - steps)
        rows.append({
            "specimen_uid": uid,
            "patient_id": grp["patient_id"].iloc[0],
            "site_id": grp["site_id"].iloc[0],
            "specimen_type": spec,
            "n_steps": len(grp),
            "missing_steps": "|".join(missing) if missing else "",
            "chain_complete": "yes" if not missing else "no",
        })
    # Passing columns= keeps the schema even when rows is empty.
    return pd.DataFrame(rows, columns=columns)


# ---------------------------------------------------------------------------
# 기능 3: cold-chain shipment QC
# ---------------------------------------------------------------------------
def shipment_qc_summary(shipment: pd.DataFrame) -> dict:
    """Summarise cold-chain shipment QC outcomes.

    accept_rate counts every non-``rejected`` shipment as accepted. Reject
    reasons exclude the literal ``"none"``. Rate and median transit hours are
    0.0 for an empty log.
    """
    total = len(shipment)
    n_rejected = int(shipment["qc_result"].eq("rejected").sum())
    n_excursion = int(shipment["temp_excursion"].eq("yes").sum())
    n_recollect = int(shipment["recollection_triggered"].eq("yes").sum())

    reason_col = shipment["reject_reason"]
    reason_counts = reason_col[reason_col.ne("none")].value_counts().to_dict()

    if total:
        accept = round((total - n_rejected) / total, 4)
        median_hours = float(shipment["transit_hours"].median())
    else:
        accept = 0.0
        median_hours = 0.0

    return {
        "total_shipments": total,
        "rejected": n_rejected,
        "accept_rate": accept,
        "temp_excursions": n_excursion,
        "recollections_triggered": n_recollect,
        "reject_reasons": reason_counts,
        "median_transit_hours": median_hours,
    }


def recollection_trigger_list(shipment: pd.DataFrame) -> pd.DataFrame:
    """Return shipments whose QC outcome triggered a recollection.

    The result keeps the original index and is sorted by site, then patient.
    """
    wanted = [
        "shipment_id", "patient_id", "site_id", "specimen_type", "courier",
        "transit_hours", "max_temp_observed_c", "temp_excursion",
        "qc_result", "reject_reason", "recollection_triggered",
    ]
    triggered = shipment["recollection_triggered"].eq("yes")
    hits = shipment.loc[triggered, wanted]
    return hits.sort_values(by=["site_id", "patient_id"])


# ---------------------------------------------------------------------------
# 기능 4: central pathology turnaround
# ---------------------------------------------------------------------------
def turnaround_summary(pathology: pd.DataFrame, target_tat: int = 14) -> dict:
    """Summarise central-pathology read turnaround against *target_tat* days.

    Turnaround statistics use only ``read_complete`` rows whose
    ``turnaround_days`` parses as a number (others are ignored). A breach is
    a turnaround strictly greater than *target_tat*. With no usable values
    the median is reported as 0.0 and breaches as 0.
    """
    status = pathology["read_status"]
    n_pending = int(status.eq("pending").sum())
    finished = pathology[status == "read_complete"]

    if finished.empty:
        median_days, n_breach = 0.0, 0
    else:
        days = pd.to_numeric(finished["turnaround_days"], errors="coerce").dropna()
        median_days = 0.0 if days.empty else float(days.median())
        n_breach = int(days.gt(target_tat).sum())

    n_reread = int(pathology["reread_required"].eq("yes").sum())
    return {
        "total_blocks": len(pathology),
        "read_complete": len(finished),
        "pending_backlog": n_pending,
        "median_turnaround_days": median_days,
        "tat_breaches": n_breach,
        "target_tat_days": target_tat,
        "reread_required": n_reread,
    }


def pathology_backlog_list(pathology: pd.DataFrame) -> pd.DataFrame:
    """Return pathology blocks needing follow-up.

    A block is listed when its read is pending, its ``tat_breach`` flag is
    "yes", or a re-read is required. Sorted by read_status, then site_id;
    the original index is kept.
    """
    wanted = [
        "block_id", "patient_id", "site_id", "central_receipt_date",
        "read_complete_date", "turnaround_days", "read_status",
        "tat_breach", "reread_required",
    ]
    needs_attention = (
        pathology["read_status"].eq("pending")
        | pathology["tat_breach"].eq("yes")
        | pathology["reread_required"].eq("yes")
    )
    flagged = pathology.loc[needs_attention, wanted]
    return flagged.sort_values(by=["read_status", "site_id"])


# ---------------------------------------------------------------------------
# 기능 5: block 재고 · 재채취 대장 · evaluable 손실 위험
# ---------------------------------------------------------------------------
def block_inventory(pathology: pd.DataFrame) -> pd.DataFrame:
    """Return the residual-section inventory per block, lowest stock first."""
    inventory = pathology.loc[
        :, ["block_id", "patient_id", "site_id", "residual_sections", "read_status"]
    ]
    return inventory.sort_values(by="residual_sections")


def recollection_register(manifest: pd.DataFrame, shipment: pd.DataFrame) -> pd.DataFrame:
    """Build the recollection register from missed collections and QC-triggered recollections.

    Missed manifest rows get reason ``missed_collection``. Shipments with
    ``recollection_triggered == "yes"`` use their reject_reason as the reason
    and ``(shipment_qc)`` as the visit. Both sources are stacked with a fresh
    index, then sorted by site_id and patient_id.
    """
    layout = ["patient_id", "site_id", "visit", "specimen_type", "reason"]

    missed_rows = manifest.loc[
        manifest["status"] == "missed",
        ["patient_id", "site_id", "visit", "specimen_type"],
    ].assign(reason="missed_collection")

    qc_rows = (
        shipment.loc[
            shipment["recollection_triggered"] == "yes",
            ["patient_id", "site_id", "specimen_type", "reject_reason"],
        ]
        .rename(columns={"reject_reason": "reason"})
        .assign(visit="(shipment_qc)")
    )[layout]

    combined = pd.concat([missed_rows, qc_rows], ignore_index=True)
    return combined.sort_values(["site_id", "patient_id"])


def evaluable_loss_risk(manifest: pd.DataFrame, shipment: pd.DataFrame,
                        pathology: pd.DataFrame) -> pd.DataFrame:
    """Flag patients at risk of losing evaluability for the LiverBiopsy histology endpoint.

    A patient collects one risk-factor tag for each of:

    - a LiverBiopsy manifest row with status ``missed``   -> ``missed_biopsy_<visit>``
    - a collected LiverBiopsy row with ``window_violation == "yes"``
                                                          -> ``window_violation_<visit>``
    - a LiverBiopsy shipment with ``qc_result == "rejected"``
                                                          -> ``qc_reject_<reject_reason>``
    - any pathology block with ``read_status == "pending"`` -> ``pathology_pending``

    NOTE(review): every biopsy visit is considered, not only baseline/wk52 as
    an earlier description claimed; confirm which scope is intended.

    Returns:
        One row per flagged patient with columns patient_id, site_id (the site
        of the first record that flagged the patient), n_risk_factors (number
        of distinct tags) and risk_factors ("|"-joined, sorted). Rows are
        ordered by n_risk_factors descending; ties keep discovery order (stable
        sort). With no flagged patients an empty frame with these columns is
        returned.
    """
    columns = ["patient_id", "site_id", "n_risk_factors", "risk_factors"]
    # patient_id -> {"site_id": first site seen, "risk_factors": set of tags}
    risk: dict = {}

    def add(pid, site, factor):
        # setdefault keeps the site from the first record that flagged pid.
        entry = risk.setdefault(pid, {"site_id": site, "risk_factors": set()})
        entry["risk_factors"].add(factor)

    bio = manifest[manifest["specimen_type"] == "LiverBiopsy"]

    missed = bio[bio["status"] == "missed"]
    for pid, site, visit in zip(missed["patient_id"], missed["site_id"], missed["visit"]):
        add(pid, site, f"missed_biopsy_{visit}")

    late = bio[(bio["window_violation"] == "yes") & (bio["status"] == "collected")]
    for pid, site, visit in zip(late["patient_id"], late["site_id"], late["visit"]):
        add(pid, site, f"window_violation_{visit}")

    rej_bio = shipment[
        (shipment["specimen_type"] == "LiverBiopsy") &
        (shipment["qc_result"] == "rejected")
    ]
    for pid, site, reason in zip(rej_bio["patient_id"], rej_bio["site_id"],
                                 rej_bio["reject_reason"]):
        add(pid, site, f"qc_reject_{reason}")

    pend = pathology[pathology["read_status"] == "pending"]
    for pid, site in zip(pend["patient_id"], pend["site_id"]):
        add(pid, site, "pathology_pending")

    rows = [
        {
            "patient_id": pid,
            "site_id": d["site_id"],
            "n_risk_factors": len(d["risk_factors"]),
            "risk_factors": "|".join(sorted(d["risk_factors"])),
        }
        for pid, d in risk.items()
    ]
    # columns= keeps the schema for an empty result; a stable sort makes tie
    # ordering deterministic (default quicksort is not stable).
    out = pd.DataFrame(rows, columns=columns)
    return out.sort_values("n_risk_factors", ascending=False, kind="stable")
