"""
core.py — ObesityTrialProtocolAmend-Kor 공용 분석 로직.

Streamlit app(app.py)과 CLI(main.py)가 함께 사용하는 순수 Python 모듈.
외부 네트워크 호출 없이 data/*.json mock data만 사용한다.
"""
from __future__ import annotations

import json
import os
from dataclasses import dataclass, field
from datetime import datetime, date
from typing import Any, Dict, Iterable, List, Optional, Tuple

# Directory that contains this module, and the "data" folder beside it from
# which the loaders below read their default JSON files.
HERE = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(HERE, "data")

# User-facing disclaimer (Korean) that the report builders in this module embed
# in their output. It is a runtime string and must stay exactly as written.
DISCLAIMER = (
    "본 도구는 연구·참고용이며, 실제 trial 결정·규제 보고·정책 판단은 "
    "sponsor/CRO/규제기관의 책임입니다. mock data 기반 시뮬레이션입니다."
)

# Anti-obesity condition / intervention keywords.
# is_obesity_trial() tests each entry as a plain substring of the lower-cased
# trial text, so every entry has to be lower-case itself.
# NOTE(review): very short entries such as "sg", "gip" or "mash" also match
# inside longer, unrelated words — confirm that this is acceptable.
OBESITY_CONDITION_KEYWORDS = [
    "obesity",
    "overweight",
    "weight loss",
    "morbid obesity",
    "sarcopenic obesity",
    "masld",
    "mash",
    "weight management",
    "pediatric obesity",
]

# Drug classes, procedures and compound names; matched the same way as the
# condition keywords above.
OBESITY_INTERVENTION_KEYWORDS = [
    "glp-1",
    "gip",
    "amylin",
    "mc4r",
    "triple agonist",
    "dual agonist",
    "bariatric",
    "rygb",
    "sleeve",
    "sg",
    "sadi",
    "balloon",
    "dtx",
    "digital therapeutic",
    "semaglutide",
    "tirzepatide",
    "retatrutide",
    "orforglipron",
    "cagrisema",
    "survodutide",
    "naltrexone",
    "bupropion",
    "danuglipron",
    "maritide",
    "petrelintide",
    "vk2735",
    "lb54640",
    "da-1241",
]

# Registry names in a fixed order. Not referenced by any code visible in this
# module; presumably the display order used by callers — TODO confirm.
REGISTRY_ORDER = ["ClinicalTrials.gov", "EudraCT", "jRCT", "CRIS-Korea", "ANZCTR"]


# ---------------------------------------------------------------------------
# 로딩
# ---------------------------------------------------------------------------
def load_trials(path: Optional[str] = None) -> List[Dict[str, Any]]:
    """Read the trial records from a JSON file and return the parsed content.

    Args:
        path: Location of the JSON file. When omitted (or empty),
            ``trials.json`` inside ``DATA_DIR`` is used.

    Returns:
        Whatever ``json.load`` produces for the file, which is decoded as UTF-8.
    """
    source = path if path else os.path.join(DATA_DIR, "trials.json")
    with open(source, mode="r", encoding="utf-8") as handle:
        parsed = json.load(handle)
    return parsed


def load_rules(path: Optional[str] = None) -> Dict[str, Any]:
    """Read the scoring rules from a JSON file and return the parsed content.

    Args:
        path: Location of the JSON file. When omitted (or empty),
            ``scoring_rules.json`` inside ``DATA_DIR`` is used.

    Returns:
        Whatever ``json.load`` produces for the file, which is decoded as UTF-8.
    """
    source = path if path else os.path.join(DATA_DIR, "scoring_rules.json")
    with open(source, mode="r", encoding="utf-8") as handle:
        parsed = json.load(handle)
    return parsed


# ---------------------------------------------------------------------------
# 필터
# ---------------------------------------------------------------------------
def is_obesity_trial(trial: Dict[str, Any]) -> bool:
    """Decide whether *trial* looks like an anti-obesity trial.

    The ``condition``, ``intervention`` and ``drug_class`` values are
    stringified, joined with spaces and lower-cased. The trial qualifies as
    soon as that text contains any entry of ``OBESITY_CONDITION_KEYWORDS`` or
    ``OBESITY_INTERVENTION_KEYWORDS`` as a plain substring (no word-boundary
    check is made).
    """
    inspected_fields = ("condition", "intervention", "drug_class")
    text = " ".join(str(trial.get(name, "")) for name in inspected_fields).lower()

    for keyword in OBESITY_CONDITION_KEYWORDS:
        if keyword in text:
            return True
    for keyword in OBESITY_INTERVENTION_KEYWORDS:
        if keyword in text:
            return True
    return False


def filter_obesity(trials: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Return, in their original order, the trials accepted by ``is_obesity_trial``."""
    return list(filter(is_obesity_trial, trials))


def filter_by_registry(trials: List[Dict[str, Any]], registries: Iterable[str]) -> List[Dict[str, Any]]:
    """Keep the trials whose ``registry`` value is one of *registries*.

    *registries* is consumed once into a set, so any iterable (including a
    one-shot iterator) is accepted. Input order is preserved.
    """
    wanted = frozenset(registries)
    selected: List[Dict[str, Any]] = []
    for record in trials:
        if record.get("registry") in wanted:
            selected.append(record)
    return selected


# ---------------------------------------------------------------------------
# Amendment 채점
# ---------------------------------------------------------------------------
def _parse_date(s: Optional[str]) -> Optional[date]:
    """Parse a ``YYYY-MM-DD`` string into a ``date``.

    Returns ``None`` for a falsy input, for text that does not fit the format
    and for values ``strptime`` rejects with ``TypeError`` (non-strings).
    """
    if not s:
        return None
    try:
        parsed = datetime.strptime(s, "%Y-%m-%d")
    except (ValueError, TypeError):
        return None
    return parsed.date()


def _month_delta(d1: date, d2: date) -> int:
    """Return how many calendar months *d2* lies after *d1*.

    Only year and month take part in the calculation; the day of the month is
    ignored. The result is negative when *d2* is in an earlier month.
    """
    start_index = d1.year * 12 + d1.month
    end_index = d2.year * 12 + d2.month
    return end_index - start_index


def _sample_size_delta_pct(before: Any, after: Any) -> Optional[float]:
    """Return the change from *before* to *after* as a percentage of *before*.

    Both values are converted with ``float``. ``None`` is returned when either
    conversion fails or when *before* is zero (no percentage is defined).
    """
    try:
        old = float(before)
        new = float(after)
    except (TypeError, ValueError):
        return None
    if old == 0:
        return None
    return (new - old) / old * 100.0


def score_amendment(amendment: Dict[str, Any], rules: Dict[str, Any]) -> Dict[str, Any]:
    """Score one amendment against every rule in ``rules["rules"]``.

    A rule is considered only when it has no ``match_type`` or its
    ``match_type`` equals the amendment's ``type``. What makes a considered
    rule match depends on its ``id``:

    * ``primary_endpoint_change``, ``sponsor_change``, ``funding_change``,
      ``comparator_change``: ``before`` differs from ``after``.
    * ``primary_endpoint_weakening``, ``analysis_plan_change_posthoc``,
      ``exclusion_tightened_safety``: one of the rule's ``keywords`` occurs,
      case-insensitively, in the note, the reason or the ``after`` value.
    * ``sample_size_increase_large``: the sample size grew by at least
      ``delta_pct_min`` percent (default 20).
    * ``sample_size_decrease``: the sample size shrank.
    * ``termination``: ``after`` is a string containing one of the rule's
      ``keywords_after`` (this comparison is case-sensitive).
    * ``completion_date_delay``: the date moved back by at least
      ``delay_months_min`` months (default 12).
    * ``completion_date_delay_short``: the delay in months lies in
      ``[delay_months_min, delay_months_max)`` (defaults 3 and 12).

    Rules carrying any other id never match.

    Args:
        amendment: Amendment record; the keys ``type``, ``before``, ``after``,
            ``note`` and ``reason`` are read, all of them optional.
        rules: Rule configuration; only the ``rules`` list is used here.

    Returns:
        A dict with ``amendment`` (the input object), ``matched_rules`` (the
        matching rule dicts, extended with ``_delta_pct`` or ``_delay_months``
        where applicable), ``score`` (sum of the matched rules' ``score``) and
        ``stars`` (one star per started block of three points, at most three,
        empty when the score is not positive).

    Raises:
        KeyError: If a considered rule has no ``id``.
    """
    value_change_ids = (
        "primary_endpoint_change",
        "sponsor_change",
        "funding_change",
        "comparator_change",
    )
    keyword_ids = (
        "primary_endpoint_weakening",
        "analysis_plan_change_posthoc",
        "exclusion_tightened_safety",
    )

    a_type = amendment.get("type")
    before = amendment.get("before")
    after = amendment.get("after")
    note = (amendment.get("note") or "").lower()
    reason = (amendment.get("reason") or "").lower()
    # A missing ``after`` must not leak the literal text "none" into the
    # searchable text, where a keyword could accidentally hit it.
    after_text = "" if after is None else str(after).lower()
    combined_text = f"{note} {reason} {after_text}"

    def mentions_any(keywords: Iterable[str]) -> bool:
        # The text is lower-cased, so the keywords have to be lower-cased too;
        # otherwise a keyword containing a capital letter could never match.
        return any(k.lower() in combined_text for k in keywords)

    matched: List[Dict[str, Any]] = []
    for rule in rules.get("rules", []):
        # Skip rules that are restricted to a different amendment type.
        match_type = rule.get("match_type")
        if match_type and match_type != a_type:
            continue

        rid = rule["id"]

        if rid in value_change_ids:
            if before != after:
                matched.append(rule)
        elif rid in keyword_ids:
            if mentions_any(rule.get("keywords", [])):
                matched.append(rule)
        elif rid == "sample_size_increase_large":
            pct = _sample_size_delta_pct(before, after)
            if pct is not None and pct >= rule.get("delta_pct_min", 20.0):
                matched.append({**rule, "_delta_pct": round(pct, 1)})
        elif rid == "sample_size_decrease":
            pct = _sample_size_delta_pct(before, after)
            if pct is not None and pct < 0:
                matched.append({**rule, "_delta_pct": round(pct, 1)})
        elif rid == "termination":
            # Compared against the raw ``after`` value, i.e. case-sensitively.
            ka = rule.get("keywords_after", [])
            if isinstance(after, str) and any(k in after for k in ka):
                matched.append(rule)
        elif rid in ("completion_date_delay", "completion_date_delay_short"):
            db = _parse_date(before)
            da = _parse_date(after)
            if db and da:
                delta = _month_delta(db, da)
                if rid == "completion_date_delay":
                    is_hit = delta >= rule.get("delay_months_min", 12)
                else:
                    lo = rule.get("delay_months_min", 3)
                    hi = rule.get("delay_months_max", 12)
                    is_hit = lo <= delta < hi
                if is_hit:
                    matched.append({**rule, "_delay_months": delta})

    total = sum(int(r.get("score", 0)) for r in matched)
    return {
        "amendment": amendment,
        "matched_rules": matched,
        "score": total,
        "stars": "★" * min(3, max(0, (total + 2) // 3)) if total > 0 else "",
    }


def classify_termination_reason(reason_text: str, rules: Dict[str, Any]) -> str:
    """Map a free-text termination reason to a category label.

    The categories in ``rules["termination_reason_categories"]`` are tried in
    order; the first one with a keyword that occurs (case-insensitively) in
    *reason_text* supplies the returned ``label``. Without any hit, or for an
    empty or missing text, the fallback label "기타" is returned.
    """
    haystack = reason_text.lower() if reason_text else ""
    categories = rules.get("termination_reason_categories", [])
    for category in categories:
        if any(keyword.lower() in haystack for keyword in category.get("keywords", [])):
            return category["label"]
    return "기타"


# ---------------------------------------------------------------------------
# Trial 단위 위기 점수
# ---------------------------------------------------------------------------
@dataclass
class TrialScore:
    """Scoring result for a single trial, as assembled by ``compute_trial_score``."""

    # Identification and descriptive values copied from the trial record.
    trial_id: str
    name: str
    sponsor: str
    country: str
    drug_class: str
    registry: str
    status: str
    korean_site: bool
    # Sum of the scores of all amendments of the trial.
    crisis_score: int
    # One ``score_amendment`` result per amendment.
    amendment_scores: List[Dict[str, Any]] = field(default_factory=list)
    # Status / rule labels raised for the trial.
    flags: List[str] = field(default_factory=list)
    # Days from the reference date to the estimated readout (negative = past).
    readout_dday: Optional[int] = None
    primary_completion: Optional[str] = None
    # Category label of the termination reason, when one was determined.
    termination_reason: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a flat dict view of the score.

        The per-amendment details are not included; they are summarised as
        ``amendment_count``. Key order is fixed.
        """
        leading = (
            "trial_id",
            "name",
            "sponsor",
            "country",
            "drug_class",
            "registry",
            "status",
            "korean_site",
            "crisis_score",
        )
        trailing = ("flags", "readout_dday", "primary_completion", "termination_reason")

        row: Dict[str, Any] = {attr: getattr(self, attr) for attr in leading}
        row["amendment_count"] = len(self.amendment_scores)
        for attr in trailing:
            row[attr] = getattr(self, attr)
        return row


def compute_trial_score(
    trial: Dict[str, Any],
    rules: Dict[str, Any],
    today: Optional[date] = None,
) -> TrialScore:
    """Score every amendment of *trial* and condense the result into a ``TrialScore``.

    Args:
        trial: Trial record. ``amendments`` may be absent or ``None``; both
            are treated as "no amendments".
        rules: Rule configuration handed to ``score_amendment`` and
            ``classify_termination_reason``.
        today: Reference date for the readout countdown; defaults to the
            current date.

    Returns:
        A ``TrialScore`` where

        * ``crisis_score`` is the sum of all amendment scores;
        * ``flags`` holds ``STATUS:<status>`` for a terminated, withdrawn or
          suspended trial plus the labels of matched endpoint / comparator /
          sample-size-decrease rules, de-duplicated and sorted;
        * ``termination_reason`` is classified from the first ``status``
          amendment whose ``after`` equals the trial status (halted trials
          only);
        * ``readout_dday`` is the number of days from *today* to
          ``readout_estimated`` (``None`` when that date is missing or
          unparseable).
    """
    today = today or date.today()
    # ``or []`` also covers an explicit ``"amendments": null`` in the JSON.
    amendments = trial.get("amendments") or []
    amendment_scores = [score_amendment(a, rules) for a in amendments]
    crisis = sum(s["score"] for s in amendment_scores)

    flags: List[str] = []
    term_reason: Optional[str] = None
    status = trial.get("status", "")
    if status in {"Terminated", "Withdrawn", "Suspended"}:
        flags.append(f"STATUS:{status}")
        # The first status amendment that led to the current status carries
        # the reason text to classify.
        status_change = next(
            (
                s["amendment"]
                for s in amendment_scores
                if s["amendment"].get("type") == "status"
                and s["amendment"].get("after") == status
            ),
            None,
        )
        if status_change is not None:
            term_reason = classify_termination_reason(status_change.get("reason", ""), rules)

    flagged_rule_ids = {
        "primary_endpoint_change",
        "primary_endpoint_weakening",
        "comparator_change",
        "sample_size_decrease",
    }
    for s in amendment_scores:
        for r in s["matched_rules"]:
            if r["id"] in flagged_rule_ids:
                # Fall back to the rule id when a rule has no label, the same
                # way the weekly digest names matched rules.
                flags.append(r.get("label", r.get("id", "")))

    # Readout d-day: positive while the readout is still ahead.
    readout = _parse_date(trial.get("readout_estimated"))
    dday = (readout - today).days if readout else None

    return TrialScore(
        trial_id=trial.get("trial_id", ""),
        name=trial.get("name", ""),
        sponsor=trial.get("sponsor", ""),
        country=trial.get("country", ""),
        drug_class=trial.get("drug_class", ""),
        registry=trial.get("registry", ""),
        status=status,
        korean_site=bool(trial.get("korean_site", False)),
        crisis_score=crisis,
        amendment_scores=amendment_scores,
        flags=sorted(set(flags)),
        readout_dday=dday,
        primary_completion=trial.get("primary_completion"),
        termination_reason=term_reason,
    )


def score_all_trials(
    trials: List[Dict[str, Any]],
    rules: Dict[str, Any],
    today: Optional[date] = None,
) -> List[TrialScore]:
    """Apply ``compute_trial_score`` to each trial and return the results in input order."""
    results: List[TrialScore] = []
    for trial in trials:
        results.append(compute_trial_score(trial, rules, today=today))
    return results


# ---------------------------------------------------------------------------
# 집계 / dashboard
# ---------------------------------------------------------------------------
def aggregate_by(scores: List[TrialScore], key: str) -> List[Tuple[str, int, int]]:
    """Group trial scores by the attribute *key* and total them per group.

    Args:
        scores: Trial scores to aggregate.
        key: Attribute name to group by. The attribute value is stringified;
            a missing, ``None`` or empty value goes into the ``"Unknown"``
            group, whereas legitimate falsy values such as ``False`` or ``0``
            keep their own group (``"False"``, ``"0"``).

    Returns:
        ``(group, trial_count, crisis_sum)`` tuples, ordered by crisis sum and
        then trial count, both descending. Groups that tie on both keep the
        order in which they were first seen.
    """
    totals: Dict[str, List[int]] = {}
    for s in scores:
        value = getattr(s, key, None)
        if isinstance(value, (bool, int, float)):
            # False / 0 are real values (e.g. korean_site=False), not "unknown".
            group = str(value)
        else:
            group = str(value) if value else "Unknown"
        bucket = totals.setdefault(group, [0, 0])
        bucket[0] += 1
        bucket[1] += s.crisis_score
    rows = [(group, n, crisis) for group, (n, crisis) in totals.items()]
    rows.sort(key=lambda r: (-r[2], -r[1]))
    return rows


def korean_sponsor_highlight(scores: List[TrialScore]) -> List[TrialScore]:
    """Select the trials that have a Korean site or a Korean-looking sponsor.

    A sponsor counts as Korean when its name contains one of a fixed list of
    name fragments, compared case-insensitively as plain substrings. Input
    order is preserved.
    """
    fragments = [
        fragment.lower()
        for fragment in (
            "LG",
            "종근당",
            "Chong Kun Dang",
            "Dong-A",
            "동아",
            "Ildong",
            "일동",
            "Yuhan",
            "유한",
            "Severance",
        )
    ]

    def is_korea_related(item: TrialScore) -> bool:
        # The sponsor name is only inspected for trials without a Korean site.
        if item.korean_site:
            return True
        sponsor_lc = item.sponsor.lower()
        return any(fragment in sponsor_lc for fragment in fragments)

    return [item for item in scores if is_korea_related(item)]


# ---------------------------------------------------------------------------
# Leading indicator alerts
# ---------------------------------------------------------------------------
def leading_indicator_alerts(
    scores: List[TrialScore], rules: Dict[str, Any]
) -> List[Dict[str, Any]]:
    """Build the ALERT / WATCH list from the trial scores.

    Thresholds come from ``rules["crisis_score_thresholds"]`` (``alert``,
    default 6; ``watch``, default 3). A trial is an ALERT at or above the
    alert threshold, and a WATCH at or above the watch threshold or — whatever
    its score — when its status is Terminated, Withdrawn or Suspended.

    Returns:
        One dict per flagged trial, ALERTs first, each level ordered by crisis
        score descending.
    """
    thresholds = rules.get("crisis_score_thresholds", {})
    alert_min = thresholds.get("alert", 6)
    watch_min = thresholds.get("watch", 3)
    halted_statuses = {"Terminated", "Withdrawn", "Suspended"}

    found: List[Dict[str, Any]] = []
    for item in scores:
        is_halted = item.status in halted_statuses
        if item.crisis_score >= alert_min:
            level = "ALERT"
        elif item.crisis_score >= watch_min or is_halted:
            level = "WATCH"
        else:
            continue
        found.append(
            dict(
                level=level,
                trial_id=item.trial_id,
                name=item.name,
                sponsor=item.sponsor,
                crisis_score=item.crisis_score,
                flags=item.flags,
                status=item.status,
                termination_reason=item.termination_reason,
            )
        )
    return sorted(found, key=lambda entry: (entry["level"] != "ALERT", -entry["crisis_score"]))


# ---------------------------------------------------------------------------
# Weekly digest (Markdown)
# ---------------------------------------------------------------------------
def build_weekly_digest(
    scores: List[TrialScore],
    rules: Dict[str, Any],
    today: Optional[date] = None,
    min_stars: int = 2,
) -> str:
    """Render the weekly digest as a Markdown document.

    The digest has a title line with the date, the disclaimer as a block
    quote, and five sections:

    1. leading-indicator alerts (at most 15),
    2. important amendments (at most 25),
    3. the ten trials with the highest crisis score,
    4. readout countdown for readouts that are upcoming or at most 30 days
       overdue (at most 20),
    5. trials with a Korean sponsor or a Korean site.

    A section without entries shows a "(해당 없음)" placeholder line.

    Args:
        scores: Trial scores to report on.
        rules: Rule configuration, used for the alert thresholds.
        today: Date shown in the title; defaults to the current date.
        min_stars: Minimum amendment *score* for section 2. The value is
            compared with the score, not with the rendered star count.

    Returns:
        The Markdown text, lines joined with ``"\\n"`` (no trailing newline).
    """
    today = today or date.today()
    none_line = "- (해당 없음)"

    lines: List[str] = [
        f"# ObesityTrialProtocolAmend-Kor 주간 다이제스트 ({today.isoformat()})",
        "",
        f"> {DISCLAIMER}",
        "",
    ]

    # --- 1) alerts ---------------------------------------------------------
    lines.append("## 1) Leading indicator 알람")
    alerts = leading_indicator_alerts(scores, rules)
    if not alerts:
        lines.append(none_line)
    for a in alerts[:15]:
        reason_part = f", 종료사유: {a['termination_reason']}" if a["termination_reason"] else ""
        flag_part = f", 플래그: {', '.join(a['flags'])}" if a["flags"] else ""
        lines.append(
            f"- **[{a['level']}]** {a['name']} ({a['trial_id']}, {a['sponsor']}) — crisis {a['crisis_score']}점"
            f"{reason_part}{flag_part}"
        )
    lines.append("")

    # --- 2) important amendments ------------------------------------------
    # "Two stars or more" is treated as equivalent to score >= min_stars.
    important: List[Tuple[TrialScore, Dict[str, Any]]] = [
        (s, a) for s in scores for a in s.amendment_scores if a["score"] >= min_stars
    ]
    # ``or ""`` keeps the sort total when an amendment has no date (or an
    # explicit None); comparing None with a date string would raise TypeError.
    important.sort(key=lambda pair: (-pair[1]["score"], pair[1]["amendment"].get("date") or ""))

    lines.append(f"## 2) 중요 amendment (★★ 이상, score ≥ {min_stars})")
    if not important:
        lines.append(none_line)
    for s, a in important[:25]:
        am = a["amendment"]
        rule_labels = ", ".join(r.get("label", r.get("id", "")) for r in a["matched_rules"])
        lines.append(
            f"- {am.get('date')} · {s.name} ({s.trial_id}) · type={am.get('type')} · score={a['score']}"
        )
        lines.append(
            f"  - before: `{am.get('before')}` → after: `{am.get('after')}`"
        )
        lines.append(
            f"  - 사유: {am.get('reason')} / 노트: {am.get('note')} / 매칭규칙: {rule_labels}"
        )
    lines.append("")

    # --- 3) top crisis scores ---------------------------------------------
    lines.append("## 3) Trial 위기점수 상위 10")
    top = sorted(scores, key=lambda x: -x.crisis_score)[:10]
    if not top:
        # Same placeholder as the other sections instead of an empty section.
        lines.append(none_line)
    for s in top:
        lines.append(
            f"- {s.name} ({s.trial_id}, {s.sponsor}, {s.drug_class}) — {s.crisis_score}점 · "
            f"amend {len(s.amendment_scores)}건 · status={s.status}"
            f"{' · KR site' if s.korean_site else ''}"
        )
    lines.append("")

    # --- 4) readout countdown ---------------------------------------------
    lines.append("## 4) Readout D-day 카운트다운 (예정일 기준)")
    upcoming = [s for s in scores if s.readout_dday is not None and s.readout_dday >= -30]
    upcoming.sort(key=lambda x: x.readout_dday)
    if not upcoming:
        lines.append(none_line)
    for s in upcoming[:20]:
        # D+n: readout date passed n days ago; D-n: n days to go.
        if s.readout_dday < 0:
            tag = f"D+{abs(s.readout_dday)}"
        else:
            tag = f"D-{s.readout_dday}"
        lines.append(f"- {tag} · {s.name} ({s.trial_id}, {s.sponsor})")
    lines.append("")

    # --- 5) Korean sponsor / site highlight -------------------------------
    lines.append("## 5) 한국 sponsor / 한국 site trial 하이라이트")
    krs = korean_sponsor_highlight(scores)
    if not krs:
        lines.append(none_line)
    for s in krs:
        lines.append(
            f"- {s.name} ({s.trial_id}, {s.sponsor}) — crisis {s.crisis_score}점, "
            f"{len(s.amendment_scores)}건 amendment, status={s.status}"
        )

    return "\n".join(lines)


# ---------------------------------------------------------------------------
# RoB2 보조 view (systematic review용)
# ---------------------------------------------------------------------------
def rob2_view(scores: List[TrialScore]) -> List[Dict[str, Any]]:
    """Summarise each trial's amendment history as RoB2 reference signals.

    Amendment types are mapped onto the signals as follows:

    * ``comparator`` / ``blinding`` -> ``deviation_from_intervention``
    * ``analysis_plan`` -> ``missing_outcome``
    * ``primary_endpoint`` / ``secondary_endpoint`` -> ``outcome_measurement``
      and ``selective_reporting``

    NOTE(review): no amendment type sets ``randomization``; it is always
    ``False`` here — confirm that this is intended.

    Returns:
        One dict per trial with ``trial_id``, ``name``, ``sponsor``,
        ``crisis_score`` and the ``rob_signals`` mapping.
    """
    intervention_types = {"comparator", "blinding"}
    endpoint_types = {"primary_endpoint", "secondary_endpoint"}

    summary: List[Dict[str, Any]] = []
    for item in scores:
        seen_types = [entry["amendment"].get("type", "") for entry in item.amendment_scores]
        endpoint_changed = any(kind in endpoint_types for kind in seen_types)
        signals = {
            "randomization": False,
            "deviation_from_intervention": any(kind in intervention_types for kind in seen_types),
            "missing_outcome": any(kind == "analysis_plan" for kind in seen_types),
            "outcome_measurement": endpoint_changed,
            "selective_reporting": endpoint_changed,
        }
        summary.append(
            dict(
                trial_id=item.trial_id,
                name=item.name,
                sponsor=item.sponsor,
                crisis_score=item.crisis_score,
                rob_signals=signals,
            )
        )
    return summary


# ---------------------------------------------------------------------------
# Summary helpers (CLI)
# ---------------------------------------------------------------------------
def summary_text(scores: List[TrialScore], rules: Dict[str, Any]) -> str:
    """Build the plain-text summary printed by the CLI.

    The text contains the disclaimer, overall counts (trials, amendments,
    mean crisis score, ALERT / WATCH counts) and three breakdowns produced by
    ``aggregate_by``: per registry (all groups), per drug class and per
    country (first eight groups each).
    """
    alerts = leading_indicator_alerts(scores, rules)
    levels = [entry["level"] for entry in alerts]
    trial_total = len(scores)
    amendment_total = sum(len(s.amendment_scores) for s in scores)
    # max(1, ...) avoids a division by zero for an empty score list.
    mean_crisis = sum(s.crisis_score for s in scores) / max(1, trial_total)

    out = [
        "=== ObesityTrialProtocolAmend-Kor Summary ===",
        DISCLAIMER,
        "",
        f"전체 trial: {trial_total}",
        f"전체 amendment: {amendment_total}",
        f"평균 위기점수: {mean_crisis:.2f}",
        f"ALERT 건수: {levels.count('ALERT')}",
        f"WATCH 건수: {levels.count('WATCH')}",
    ]

    sections = (
        ("[Registry별]", aggregate_by(scores, "registry")),
        ("[Drug class별 (top 8)]", aggregate_by(scores, "drug_class")[:8]),
        ("[Country별 (top 8)]", aggregate_by(scores, "country")[:8]),
    )
    for title, rows in sections:
        out.append("")
        out.append(title)
        out.extend(f"  - {group}: trial {count}, crisis {crisis}" for group, count, crisis in rows)
    return "\n".join(out)


def top_n_text(scores: List[TrialScore], n: int = 10) -> str:
    """Build the plain-text ranking of the *n* trials with the highest crisis score.

    Each trial gets one numbered line followed by one indented ``flag:`` line
    per flag. Trials with equal scores keep their input order. The header
    always names *n*, even when fewer trials are available.
    """
    ranked = sorted(scores, key=lambda item: -item.crisis_score)[:n]
    out = [f"=== Top {n} crisis-score trials ===", DISCLAIMER, ""]
    for rank, item in enumerate(ranked, start=1):
        out.append(
            f"{rank:2d}. [{item.crisis_score:>2}점] {item.name} ({item.trial_id}) — "
            f"{item.sponsor} / {item.drug_class} / {item.country} / status={item.status}"
        )
        out.extend(f"      flag: {flag}" for flag in item.flags)
    return "\n".join(out)
