"""
DMTrialSupplyChain-Kor — 핵심 계산 엔진 (순수 로직, Streamlit 비의존).

5개 기능의 계산 로직을 함수 단위로 분리하여 단위 테스트가 가능하도록 함.
모든 함수는 pandas DataFrame 입출력. 외부 네트워크/API 호출 없음.

참고 reference 기준(템플릿; 실제 규제 의사결정 대체 아님):
  - ICH GCP E6(R3) §8 (IMP accountability: dispensed/returned/destroyed balance)
  - GMP cold-chain 2-8C 보관, IATA PCR(Perishable Cargo Regulations) cold-chain 원칙
  - 주사제 안정성 budget(누적 실온 허용 시간) — IMP 라벨/안정성 데이터 기반 (데모 값)
"""

from __future__ import annotations
import pandas as pd
import numpy as np


# ===========================================================================
# 기능 1: 무작위배정 연동 resupply forecasting
# ===========================================================================
def forecast_resupply(enrollment: pd.DataFrame,
                      imp_doses_per_week: dict,
                      imp_kit_doses: dict,
                      titration_factor: float = 1.0,
                      horizon_weeks: int = 12,
                      start_date: str | pd.Timestamp = "2026-01-05") -> pd.DataFrame:
    """
    Forecast IMP kit demand per site, arm and week from the enrollment log.

    Every subject enrolled so far is assumed to keep consuming doses every
    week (retention fixed at 1.0 -- a deliberate simplification).

    Args:
        enrollment: one row per enrolled subject with ``site_id``, ``arm`` and
            an integer-like ``week_index`` (weeks since study start). No other
            column is required.
        imp_doses_per_week: doses per subject per week, keyed by arm
            (arms missing from the dict default to 1).
        imp_kit_doses: doses contained in one kit, keyed by arm (missing arms
            default to 1; a falsy value makes every dose count as one kit).
        titration_factor: multiplier applied to the dose count of every week
            for all active subjects (e.g. 1.2 = +20%).
        horizon_weeks: weeks forecast beyond the last enrollment week.
        start_date: calendar date of week 0, used to render ``week_start``.

    Returns:
        DataFrame with columns site_id, arm, week_index, week_start,
        active_subjects, doses_needed, kits_needed -- one row per
        (site, arm, week) for weeks 0..max(week_index) + horizon_weeks.
        Enrollments with a negative ``week_index`` are not counted.
    """
    columns = ["site_id", "arm", "week_index", "week_start",
               "active_subjects", "doses_needed", "kits_needed"]
    if enrollment.empty:
        return pd.DataFrame(columns=columns)

    enr = enrollment.copy()
    enr["week_index"] = enr["week_index"].astype(int)

    # New enrollments per site x arm x week.
    counts = (enr.groupby(["site_id", "arm", "week_index"])
                 .size().reset_index(name="new_subjects"))

    max_week = int(enr["week_index"].max()) + horizon_weeks
    base = pd.Timestamp(start_date)
    # Week labels are identical for every site/arm, so render them once.
    week_starts = [(base + pd.Timedelta(weeks=w)).strftime("%Y-%m-%d")
                   for w in range(max_week + 1)]

    rows = []
    for (site, arm), g in counts.groupby(["site_id", "arm"]):
        new_by_week = dict(zip(g["week_index"], g["new_subjects"]))
        dpw = imp_doses_per_week.get(arm, 1)
        kit_doses = imp_kit_doses.get(arm, 1)
        cum = 0
        for w in range(max_week + 1):
            cum += int(new_by_week.get(w, 0))
            # NOTE(review): titration_factor is applied uniformly to all active
            # subjects in every week, not only to an initial titration window.
            doses = cum * dpw * titration_factor
            kits = doses / kit_doses if kit_doses else doses
            rows.append({
                "site_id": site, "arm": arm, "week_index": w,
                "week_start": week_starts[w],
                "active_subjects": cum,
                "doses_needed": round(doses, 1),
                "kits_needed": int(np.ceil(kits)),
            })
    # Passing columns keeps the schema even if no group survived groupby.
    return pd.DataFrame(rows, columns=columns)


def resupply_triggers(forecast: pd.DataFrame, sites: pd.DataFrame,
                      current_week: int = 0) -> pd.DataFrame:
    """
    Derive resupply triggers and projected stock-out weeks per (site, arm).

    Each (site_id, arm) group of ``forecast`` is matched against the ``sites``
    row with the same (site_id, imp); groups without a match are skipped.
    From ``current_week`` on, the per-week consumption (the increase of
    ``kits_needed`` over the previous week, negatives clipped to 0; the first
    week counts its full ``kits_needed``) is subtracted from ``on_hand_kits``:

      * resupply_trigger_week   -- first week the projected stock is <= par level
      * projected_stockout_week -- first week the projected stock drops below 0

    ``lead_time_days`` is reported but not used in the projection.

    Returns:
        One row per matched group with on-hand/par/lead-time inputs, both
        weeks (None when never reached), suggested_order_kits
        (max(2*par - on_hand, par) when a trigger occurs, else 0) and risk
        ('HIGH' on stock-out, 'WATCH' on trigger only, 'OK' otherwise).
        Empty DataFrame when either input is empty.
    """
    if forecast.empty or sites.empty:
        return pd.DataFrame()

    # presumably one row per (site_id, imp); duplicates would make the
    # int() conversions below fail -- confirm with the data owner.
    by_site_imp = sites.set_index(["site_id", "imp"])
    summary = []
    for key, grp in forecast.groupby(["site_id", "arm"]):
        if key not in by_site_imp.index:
            continue  # no inventory record for this site/IMP
        site, arm = key
        stock = by_site_imp.loc[key]
        on_hand = int(stock["on_hand_kits"])
        par = int(stock["par_level_kits"])
        lead = int(stock["lead_time_days"])

        upcoming = grp.sort_values("week_index")
        upcoming = upcoming[upcoming["week_index"] >= current_week]
        need = upcoming["kits_needed"]
        # NOTE(review): consumption is modelled as the week-over-week increase
        # of kits_needed rather than kits_needed itself -- confirm intent.
        consumption = need.diff().fillna(need).clip(lower=0)

        level = on_hand
        trigger_week = None
        stockout_week = None
        for week, used in zip(upcoming["week_index"], consumption):
            level -= used
            if trigger_week is None and level <= par:
                trigger_week = int(week)
            if stockout_week is None and level < 0:
                stockout_week = int(week)

        if stockout_week is not None:
            risk = "HIGH"
        elif trigger_week is not None:
            risk = "WATCH"
        else:
            risk = "OK"
        order_kits = 0 if trigger_week is None else int(max(par * 2 - on_hand, par))

        summary.append({
            "site_id": site, "arm": arm,
            "on_hand_kits": on_hand, "par_level_kits": par, "lead_time_days": lead,
            "resupply_trigger_week": trigger_week,
            "projected_stockout_week": stockout_week,
            "suggested_order_kits": order_kits,
            "risk": risk,
        })
    return pd.DataFrame(summary)


# ===========================================================================
# 기능 2: FEFO 만료관리 엔진
# ===========================================================================
def fefo_allocation(lots: pd.DataFrame, demand_kits_by_imp: dict,
                    as_of: pd.Timestamp | str | None = None) -> pd.DataFrame:
    """
    Allocate kit demand to lots per IMP in FEFO (First-Expiry-First-Out) order.

    Lots of each IMP are ranked by ascending ``expiry_date`` (lots with equal
    expiry keep their input order) and the IMP's demand is filled greedily
    starting from the earliest-expiring lot.

    Args:
        lots: one row per lot with lot_id, imp, expiry_date, qty_kits.
        demand_kits_by_imp: kits to allocate per IMP; IMPs not listed get 0
            and negative values are treated as 0.
        as_of: reference date for ``days_to_expiry`` (anything accepted by
            ``pd.Timestamp``); defaults to 2026-01-05.

    Returns:
        One row per lot: lot_id, imp, expiry_date (YYYY-MM-DD),
        days_to_expiry, fefo_rank (1 = expires first within its IMP),
        qty_kits, allocated_kits, remaining_kits. Empty DataFrame when
        ``lots`` is empty.
    """
    if lots.empty:
        return pd.DataFrame()
    as_of = pd.Timestamp("2026-01-05") if as_of is None else pd.Timestamp(as_of)

    df = lots.copy()
    df["expiry_date"] = pd.to_datetime(df["expiry_date"])
    df["days_to_expiry"] = (df["expiry_date"] - as_of).dt.days

    out = []
    for imp, g in df.groupby("imp"):
        # mergesort is stable -> deterministic ranking for equal expiry dates.
        g = g.sort_values("expiry_date", kind="mergesort")
        # Clamp so a negative demand cannot create negative allocations.
        remaining_demand = max(int(demand_kits_by_imp.get(imp, 0)), 0)
        for rank, (_, lot) in enumerate(g.iterrows(), start=1):
            avail = int(lot["qty_kits"])
            alloc = max(min(avail, remaining_demand), 0)
            remaining_demand -= alloc
            out.append({
                "lot_id": lot["lot_id"], "imp": imp,
                "expiry_date": lot["expiry_date"].strftime("%Y-%m-%d"),
                "days_to_expiry": int(lot["days_to_expiry"]),
                "fefo_rank": rank,
                "qty_kits": avail,
                "allocated_kits": alloc,
                "remaining_kits": avail - alloc,
            })
    return pd.DataFrame(out)


def expiry_waste(lots: pd.DataFrame, demand_kits_by_imp: dict,
                 horizon_days: int = 180, kit_cost: float = 250000.0,
                 as_of: pd.Timestamp | None = None) -> pd.DataFrame:
    """
    Estimate kits that expire unused within ``horizon_days`` and their cost.

    Runs ``fefo_allocation`` and keeps lots that still have unallocated kits
    and expire within the horizon (``days_to_expiry <= horizon_days``, which
    also includes lots already past expiry).

    Args:
        kit_cost: cost per kit in KRW (demo default 250,000).

    Returns:
        lot_id, imp, expiry_date, days_to_expiry, remaining_kits and
        est_waste_cost (= remaining_kits * kit_cost); an empty DataFrame when
        there are no lots.
    """
    allocation = fefo_allocation(lots, demand_kits_by_imp, as_of=as_of)
    if allocation.empty:
        return pd.DataFrame()

    leftover = allocation["remaining_kits"] > 0
    expiring = allocation["days_to_expiry"] <= horizon_days
    at_risk = allocation.loc[leftover & expiring].copy()
    at_risk["est_waste_cost"] = at_risk["remaining_kits"] * kit_cost

    report_cols = ["lot_id", "imp", "expiry_date", "days_to_expiry",
                   "remaining_kits", "est_waste_cost"]
    return at_risk[report_cols].reset_index(drop=True)


# ===========================================================================
# 기능 3: 온도일탈(excursion) 영향 추적
# ===========================================================================
def detect_excursions(temp_log: pd.DataFrame,
                      low_c: float = 2.0, high_c: float = 8.0,
                      budget_hours_by_imp: dict | None = None) -> pd.DataFrame:
    """
    Detect temperature excursions and judge each shipment/lot.

    A reading is an excursion when ``temp_c`` is above ``high_c`` or below
    ``low_c`` (the bounds themselves are in range). Readings are grouped by
    (shipment_id, lot_id, imp, site_id); cumulative excursion time is the
    number of excursion readings times the group's median sampling interval
    (1 hour when it cannot be estimated, e.g. a single reading).

    Disposition is 'QUARANTINE' when the cumulative hours exceed the IMP's
    budget in ``budget_hours_by_imp`` and 'USABLE' otherwise. IMPs without a
    budget are conservatively 'QUARANTINE' (budget reported as NaN).

    Returns:
        shipment_id, lot_id, imp, site_id, n_excursion_points,
        cumulative_excursion_hours, budget_hours, max_temp_c, min_temp_c,
        disposition -- one row per group; empty DataFrame for an empty log.
    """
    if temp_log.empty:
        return pd.DataFrame()
    budgets = {} if budget_hours_by_imp is None else budget_hours_by_imp

    log = temp_log.copy()
    log["timestamp"] = pd.to_datetime(log["timestamp"])
    log = log.sort_values(["shipment_id", "timestamp"])
    too_warm = log["temp_c"] > high_c
    too_cold = log["temp_c"] < low_c
    log["is_excursion"] = too_warm | too_cold

    summary = []
    grouped = log.groupby(["shipment_id", "lot_id", "imp", "site_id"])
    for (shipment, lot, imp, site), readings in grouped:
        readings = readings.sort_values("timestamp")

        # Sampling interval (hours) = median gap between consecutive readings.
        step_h = 1.0
        if len(readings) > 1:
            gaps_h = readings["timestamp"].diff().dropna().dt.total_seconds() / 3600.0
            if not gaps_h.empty:
                step_h = float(gaps_h.median())

        n_points = int(readings["is_excursion"].sum())
        hours = round(n_points * step_h, 2)

        budget = budgets.get(imp)
        if budget is not None:
            budget_val = float(budget)
            verdict = "QUARANTINE" if hours > budget_val else "USABLE"
        else:
            budget_val = np.nan
            verdict = "QUARANTINE"  # unknown budget -> conservative

        summary.append({
            "shipment_id": shipment, "lot_id": lot, "imp": imp, "site_id": site,
            "n_excursion_points": n_points,
            "cumulative_excursion_hours": hours,
            "budget_hours": budget_val,
            "max_temp_c": round(float(readings["temp_c"].max()), 2),
            "min_temp_c": round(float(readings["temp_c"].min()), 2),
            "disposition": verdict,
        })
    return pd.DataFrame(summary)


# ===========================================================================
# 기능 4: drug accountability 원장
# ===========================================================================
def accountability_balance(acct: pd.DataFrame) -> pd.DataFrame:
    """
    Reconcile IMP accountability per (site_id, imp), ICH GCP E6(R3) style.

    Reference identity:
        shipped_to_site == dispensed + returned + destroyed + on_site_remaining

    ``acct`` is a transaction ledger with site_id, imp, txn_type and qty_kits.
    Quantities are summed per (site_id, imp, txn_type); txn types that never
    occur are treated as 0.

    Derived columns:
        accounted   = dispensed + returned + destroyed
        balance_gap = shipped_to_site - dispensed - returned - destroyed
                      (what should still be on site)
        alert       = 'MISMATCH' when balance_gap < 0 or
                      shipped_to_site - dispensed < 0, else 'OK'

    Returns an empty DataFrame for an empty ledger.
    """
    if acct.empty:
        return pd.DataFrame()

    ledger = acct.copy()
    totals = (ledger.pivot_table(index=["site_id", "imp"], columns="txn_type",
                                 values="qty_kits", aggfunc="sum", fill_value=0)
                    .reset_index())

    required = ("shipped_to_site", "dispensed", "returned", "destroyed")
    absent = [name for name in required if name not in totals.columns]
    for name in absent:
        totals[name] = 0

    shipped = totals["shipped_to_site"]
    dispensed = totals["dispensed"]
    returned = totals["returned"]
    destroyed = totals["destroyed"]

    totals["accounted"] = dispensed + returned + destroyed
    on_site_expected = shipped - dispensed
    totals["balance_gap"] = shipped - dispensed - returned - destroyed
    # Negative remaining stock on either measure means the ledger does not reconcile.
    negative = (totals["balance_gap"] < 0) | (on_site_expected < 0)
    totals["alert"] = np.where(negative, "MISMATCH", "OK")

    out_cols = ["site_id", "imp", "shipped_to_site", "dispensed", "returned",
                "destroyed", "accounted", "balance_gap", "alert"]
    return totals[out_cols].reset_index(drop=True)


# ===========================================================================
# 기능 5: depot→site 배분 시나리오 (what-if 이산사건 시뮬레이션)
# ===========================================================================
def simulate_distribution(enrollment: pd.DataFrame,
                          sites: pd.DataFrame,
                          imp_doses_per_week: dict,
                          imp_kit_doses: dict,
                          enroll_accel: float = 1.0,
                          supply_delay_weeks: int = 0,
                          extra_sites: int = 0,
                          horizon_weeks: int = 24,
                          use_simpy: bool = True) -> pd.DataFrame:
    """
    What-if simulation of depot -> site distribution.

    Weekly demand comes from ``forecast_resupply``: total ``kits_needed`` per
    week (summed over sites and arms, restricted to weeks 0..horizon_weeks),
    taken as its non-negative week-over-week increase, then scaled by
    ``enroll_accel`` and by ``1 + extra_sites / n_sites``.

    Args:
        enroll_accel: enrollment acceleration multiplier (1.0 = baseline,
            1.5 = 50% faster); applied only as a demand multiplier.
        supply_delay_weeks: depot resupply delay in weeks.
        extra_sites: additional sites, assumed to enroll at the average rate
            of the existing ones (``sites`` must have a ``site_id`` column).
        horizon_weeks: last simulated week index.
        use_simpy: try the simpy discrete-event backend first; fall back to
            the deterministic model when simpy cannot be imported.

    Returns:
        week_index, total_demand_kits, total_resupplied_kits, backorder_kits,
        stockout_flag -- or an empty DataFrame when the forecast is empty.
    """
    forecast = forecast_resupply(enrollment, imp_doses_per_week, imp_kit_doses,
                                 horizon_weeks=horizon_weeks)
    if forecast.empty:
        return pd.DataFrame()

    sim_weeks = range(0, horizon_weeks + 1)
    total_kits = forecast.groupby("week_index")["kits_needed"].sum()
    total_kits = total_kits.reindex(sim_weeks, method="ffill").fillna(0)

    # Added sites scale demand relative to the current site count (at least 1).
    n_sites = max(len(sites["site_id"].unique()), 1)
    site_factor = 1.0 + (extra_sites / n_sites)
    increments = total_kits.diff().fillna(total_kits).clip(lower=0)
    demand = increments * enroll_accel * site_factor

    if not use_simpy:
        return _deterministic_distribution(demand, supply_delay_weeks, horizon_weeks)
    try:
        return _simpy_distribution(demand, supply_delay_weeks, horizon_weeks)
    except ImportError:
        # simpy is not installed -> deterministic fallback
        return _deterministic_distribution(demand, supply_delay_weeks, horizon_weeks)


def _deterministic_distribution(demand: pd.Series, supply_delay_weeks: int,
                                horizon_weeks: int) -> pd.DataFrame:
    """
    Deterministic fallback: week-by-week demand, resupply and backorder.

    The resupply arriving in week ``w`` equals the demand of week
    ``w - supply_delay_weeks``; nothing arrives before the delay has elapsed.
    The backorder accumulates unmet demand and is floored at 0, so surplus
    arrivals do not build up stock.
    """
    records = []
    carried = 0.0
    for week in range(0, horizon_weeks + 1):
        need = float(demand.get(week, 0.0))
        arrived = 0.0
        if week >= supply_delay_weeks:
            arrived = float(demand.get(week - supply_delay_weeks, 0.0))
        carried = max(carried + (need - arrived), 0.0)
        records.append({
            "week_index": week,
            "total_demand_kits": round(need, 1),
            "total_resupplied_kits": round(arrived, 1),
            "backorder_kits": round(carried, 1),
            "stockout_flag": carried > 0,
        })
    return pd.DataFrame(records)


def _simpy_distribution(demand: pd.Series, supply_delay_weeks: int,
                        horizon_weeks: int) -> pd.DataFrame:
    """
    simpy discrete-event model of the depot -> site resupply pipeline.

    Each week's demand is ordered immediately and tagged with its arrival week
    (``week + supply_delay_weeks``). Orders whose arrival week has been reached
    are then collected from the in-transit store. Unmet demand accumulates as
    a backorder floored at 0. Output columns are the same as
    ``_deterministic_distribution``, and for non-negative delays the numbers
    are the same too.

    Raises:
        ImportError: simpy is not installed (the caller falls back).
    """
    import simpy  # ImportError if not installed -> caller falls back

    env = simpy.Environment()
    in_transit = simpy.Store(env)  # unbounded FIFO of (arrival_week, qty)
    rows = []

    def weekly_cycle():
        backorder = 0.0
        for w in range(0, horizon_weeks + 1):
            d = float(demand.get(w, 0.0))
            # Place the order BEFORE collecting arrivals. put() on an unbounded
            # Store appends the item synchronously, so a zero-delay order is
            # received in the same week. Deferring the put to a child process
            # would make every arrival one week late when the delay is 0.
            in_transit.put((w + supply_delay_weeks, d))
            arrived = 0.0
            # Arrival weeks are non-decreasing in FIFO order: stop at the first future one.
            while in_transit.items and in_transit.items[0][0] <= w:
                _, qty = yield in_transit.get()
                arrived += qty
            backorder = max(backorder + (d - arrived), 0.0)
            rows.append({
                "week_index": w,
                "total_demand_kits": round(d, 1),
                "total_resupplied_kits": round(arrived, 1),
                "backorder_kits": round(backorder, 1),
                "stockout_flag": backorder > 0,
            })
            yield env.timeout(1)

    env.process(weekly_cycle())
    # Run until the process finishes. This also handles an empty horizon,
    # where run(until=0) would raise.
    env.run()
    return pd.DataFrame(rows)


# ===========================================================================
# 헬퍼: IMP 마스터 (app/data 공통)
# ===========================================================================
# IMP parameter master (demo values). Fields per IMP key:
#   label              -- display name shown to users
#   doses_per_week     -- doses per subject per week
#   kit_doses          -- doses contained in one kit
#   excursion_budget_h -- allowed cumulative temperature-excursion hours
# imp_dicts() splits these fields into per-parameter dicts.
IMP_MASTER = {
    "insulin_glargine": dict(
        label="인슐린 글라진(주사 펜)",
        doses_per_week=7,
        kit_doses=30,
        excursion_budget_h=720,
    ),
    "semaglutide": dict(
        label="세마글루타이드 GLP-1RA(펜)",
        doses_per_week=1,
        kit_doses=4,
        excursion_budget_h=1344,
    ),
    "tirzepatide": dict(
        label="터제파타이드(바이알)",
        doses_per_week=1,
        kit_doses=4,
        excursion_budget_h=504,
    ),
}


def imp_dicts():
    """
    Split IMP_MASTER into the per-parameter dicts used by the engine functions.

    Returns:
        Tuple (doses_per_week, kit_doses, excursion_budget_h). Each item is a
        dict keyed by IMP name, in IMP_MASTER order.
    """
    fields = ("doses_per_week", "kit_doses", "excursion_budget_h")
    per_field = tuple(
        {imp: spec[field] for imp, spec in IMP_MASTER.items()}
        for field in fields
    )
    return per_field
