"""HTML report generator — quarterly obesity pharma roundup."""

from collections import defaultdict
from datetime import datetime
from html import escape

from . import svgcharts


# User-facing disclaimer text (Korean), rendered HTML-escaped in the report's
# "Disclaimer" section by render_report. In English, roughly: the tool is a
# research/education reference only and must not be used as a basis for
# obesity-drug prescribing or clinical decisions; all conference abstracts,
# NCT numbers and figures are synthetic mock data that cite no real patient
# or company data.
DISCLAIMER = (
    "본 도구는 연구·교육 목적의 참고용이며, 비만약물 처방·임상 의사결정 근거로 "
    "사용할 수 없습니다. 모든 학회 abstract / NCT 번호 / 수치는 합성된 mock 데이터이며 "
    "실제 환자·기업 데이터를 인용하지 않았습니다."
)


def _quarter_of(month_str):
    """Convert YYYY-MM to YYYY-Qn."""
    if not month_str:
        return None
    try:
        y, m = month_str.split("-")
        q = (int(m) - 1) // 3 + 1
        return f"{y}-Q{q}"
    except Exception:
        return None


def filter_by_quarter(abstracts, quarter):
    """Return the abstracts presented in the requested quarter.

    Args:
        abstracts: Iterable of abstract dicts; each one's "_conference_month"
            value is converted with _quarter_of before comparison.
        quarter: Quarter label in 'YYYY-Qn' form, or a falsy value to disable
            filtering.

    Returns:
        A new list. With no quarter given it is a shallow copy of the input;
        otherwise it holds only the abstracts whose conference month maps to
        exactly that quarter label.
    """
    if quarter:
        kept = []
        for item in abstracts:
            if _quarter_of(item.get("_conference_month")) == quarter:
                kept.append(item)
        return kept
    return list(abstracts)


def _bar_data(abstracts):
    """Aggregate weight-loss % per drug (mean across abstracts)."""
    by_drug = defaultdict(list)
    phase_by_drug = {}
    for a in abstracts:
        wl = a.get("weight_loss_pct")
        if wl is None:
            continue
        for d in a.get("drugs", []):
            by_drug[d].append(float(wl))
            # Pick most-advanced phase seen.
            phase_by_drug[d] = a.get("phase") or phase_by_drug.get(d, "?")
    rows = []
    for drug, vals in by_drug.items():
        mean_val = sum(vals) / len(vals)
        rows.append((drug, mean_val, phase_by_drug.get(drug, "?")))
    rows.sort(key=lambda x: -x[1])
    return rows[:12]  # top 12 for readability


def _phase_transitions(abstracts):
    """Count abstract distribution across phases (one node per phase)."""
    counts = defaultdict(int)
    for a in abstracts:
        ph = a.get("phase")
        if ph:
            counts[ph] += 1

    # Build "transitions" approximating progression: count flow between adjacent phases
    # by looking at drugs that appear in two phases.
    drug_phases = defaultdict(set)
    for a in abstracts:
        for d in a.get("drugs", []):
            if a.get("phase"):
                drug_phases[d].add(a["phase"])

    order = ["Phase 1", "Phase 1b", "Phase 1/2", "Phase 2",
             "Phase 2a", "Phase 2b", "Phase 2/3", "Phase 3", "Phase 4"]
    rank = {p: i for i, p in enumerate(order)}

    transitions = []
    for drug, phases in drug_phases.items():
        if len(phases) < 2:
            continue
        sorted_phases = sorted(phases, key=lambda p: rank.get(p, 99))
        for i in range(len(sorted_phases) - 1):
            transitions.append({
                "from_phase": sorted_phases[i],
                "to_phase": sorted_phases[i + 1],
                "count": 1,
                "drug": drug,
            })

    # Aggregate same from→to pairs.
    agg = defaultdict(lambda: {"count": 0, "drugs": []})
    for t in transitions:
        key = (t["from_phase"], t["to_phase"])
        agg[key]["count"] += 1
        agg[key]["drugs"].append(t["drug"])

    out = []
    for (fp, tp), v in agg.items():
        out.append({
            "from_phase": fp,
            "to_phase": tp,
            "count": v["count"],
            "drugs": v["drugs"],
        })
    return out, counts


def _section_for_conference(abstracts, conf_name):
    rows = [a for a in abstracts if a.get("_conference") == conf_name]
    if not rows:
        return ""
    parts = [f'<h3 class="conf">{escape(conf_name)} <span class="count">({len(rows)} abstracts)</span></h3>']
    parts.append('<table class="abstracts"><thead><tr>'
                 '<th>Abstract ID</th><th>Title</th><th>Drug(s)</th>'
                 '<th>Phase</th><th>Weight loss %</th><th>GI AE %</th>'
                 '<th>Sponsor</th><th>Linked NCT</th></tr></thead><tbody>')
    for r in rows:
        nct_links = ", ".join(t.get("nct_id", "") for t in r.get("linked_trials", []))
        wl = r.get("weight_loss_pct")
        ae = r.get("gi_ae_pct")
        parts.append("<tr>")
        parts.append(f'<td><code>{escape(str(r.get("abstract_id", "")))}</code></td>')
        parts.append(f'<td>{escape(r.get("title", ""))}</td>')
        parts.append(f'<td>{escape(", ".join(r.get("drugs", []) or ["—"]))}</td>')
        parts.append(f'<td>{escape(str(r.get("phase") or "—"))}</td>')
        parts.append(f'<td>{("-%.1f%%" % float(wl)) if wl is not None else "—"}</td>')
        parts.append(f'<td>{("%.1f%%" % float(ae)) if ae is not None else "—"}</td>')
        parts.append(f'<td>{escape(", ".join(r.get("sponsors", []) or ["—"]))}</td>')
        parts.append(f'<td>{escape(nct_links) or "—"}</td>')
        parts.append("</tr>")
    parts.append("</tbody></table>")
    return "".join(parts)


def _drug_summary_table(abstracts):
    by_drug = defaultdict(lambda: {
        "abstracts": 0, "phases": set(), "wl": [], "ae": [],
        "sponsors": set(), "ncts": set(),
    })
    for a in abstracts:
        for d in a.get("drugs", []):
            entry = by_drug[d]
            entry["abstracts"] += 1
            if a.get("phase"):
                entry["phases"].add(a["phase"])
            if a.get("weight_loss_pct") is not None:
                entry["wl"].append(float(a["weight_loss_pct"]))
            if a.get("gi_ae_pct") is not None:
                entry["ae"].append(float(a["gi_ae_pct"]))
            for s in a.get("sponsors", []):
                entry["sponsors"].add(s)
            for t in a.get("linked_trials", []):
                entry["ncts"].add(t.get("nct_id", ""))
    rows = sorted(
        by_drug.items(),
        key=lambda kv: -(sum(kv[1]["wl"]) / len(kv[1]["wl"]) if kv[1]["wl"] else 0)
    )
    parts = ['<table class="drug-summary"><thead><tr>'
            '<th>Drug</th><th>#Abstracts</th><th>Phase(s)</th>'
            '<th>Mean weight loss</th><th>Mean GI AE</th>'
            '<th>Sponsor(s)</th><th>NCT links</th></tr></thead><tbody>']
    for drug, e in rows:
        mean_wl = sum(e["wl"]) / len(e["wl"]) if e["wl"] else None
        mean_ae = sum(e["ae"]) / len(e["ae"]) if e["ae"] else None
        parts.append("<tr>")
        parts.append(f'<td><strong>{escape(drug)}</strong></td>')
        parts.append(f'<td>{e["abstracts"]}</td>')
        parts.append(f'<td>{escape(", ".join(sorted(e["phases"])) or "—")}</td>')
        parts.append(f'<td>{("-%.1f%%" % mean_wl) if mean_wl is not None else "—"}</td>')
        parts.append(f'<td>{("%.1f%%" % mean_ae) if mean_ae is not None else "—"}</td>')
        parts.append(f'<td>{escape(", ".join(sorted(e["sponsors"])) or "—")}</td>')
        parts.append(f'<td>{escape(", ".join(sorted(e["ncts"])) or "—")}</td>')
        parts.append("</tr>")
    parts.append("</tbody></table>")
    return "".join(parts)


def _target_breakdown(abstracts):
    by_target = defaultdict(int)
    for a in abstracts:
        for t in a.get("targets", []):
            by_target[t] += 1
    rows = sorted(by_target.items(), key=lambda kv: -kv[1])
    if not rows:
        return ""
    parts = ['<table class="target-breakdown"><thead><tr><th>Target</th><th>Abstract count</th></tr></thead><tbody>']
    for t, c in rows:
        parts.append(f"<tr><td>{escape(t)}</td><td>{c}</td></tr>")
    parts.append("</tbody></table>")
    return "".join(parts)


# Inline stylesheet for the report. render_report embeds this text verbatim
# inside the page's <style> element, so the generated HTML needs no external
# CSS file. The selectors match the class names emitted by the table/section
# builders in this module (conf, count, kpis, kpi, disclaimer, svg-wrap).
CSS = """
body { font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Arial, sans-serif;
       margin: 0; padding: 0; color: #1a1a1a; background: #f9fafb; }
header { background: linear-gradient(120deg, #1e3a8a, #1e40af); color: white;
         padding: 26px 40px; }
header h1 { margin: 0; font-size: 22px; font-weight: 700; }
header p { margin: 4px 0 0; font-size: 13px; opacity: 0.9; }
main { max-width: 1100px; margin: 0 auto; padding: 24px 40px 60px; }
section { background: white; border-radius: 8px; padding: 20px 24px;
          margin-bottom: 22px; box-shadow: 0 1px 2px rgba(0,0,0,0.05); }
h2 { font-size: 17px; margin: 0 0 12px; color: #1e3a8a; border-bottom: 1px solid #e5e7eb; padding-bottom: 6px; }
h3.conf { font-size: 14px; margin: 18px 0 8px; color: #1e40af; }
.count { color: #6b7280; font-weight: normal; font-size: 12px; }
table { border-collapse: collapse; width: 100%; font-size: 12px; margin-bottom: 8px; }
th { background: #f3f4f6; text-align: left; padding: 6px 8px; border-bottom: 1px solid #d1d5db; font-weight: 600; }
td { padding: 6px 8px; border-bottom: 1px solid #f0f0f0; vertical-align: top; }
code { font-family: "SF Mono", Consolas, monospace; font-size: 11px; background: #f3f4f6; padding: 1px 4px; border-radius: 3px; }
.kpis { display: grid; grid-template-columns: repeat(4, 1fr); gap: 12px; margin-bottom: 20px; }
.kpi { background: white; border-radius: 8px; padding: 14px; box-shadow: 0 1px 2px rgba(0,0,0,0.05); }
.kpi .v { font-size: 22px; font-weight: 700; color: #1e3a8a; }
.kpi .l { font-size: 11px; color: #6b7280; text-transform: uppercase; letter-spacing: 0.5px; }
.disclaimer { background: #fef3c7; border-left: 3px solid #d97706; padding: 12px 16px;
              border-radius: 4px; font-size: 12px; color: #78350f; line-height: 1.55; }
footer { text-align: center; font-size: 11px; color: #9ca3af; padding: 24px 0; }
.svg-wrap { overflow-x: auto; }
"""


def render_report(abstracts, quarter=None, generated_at=None):
    """Build the full HTML report string.

    Args:
        abstracts: Iterable of abstract dicts. List-valued fields ("drugs",
            "sponsors", "linked_trials") may be missing or None.
        quarter: Optional 'YYYY-Qn' label; when given, only abstracts from
            that quarter are reported (see filter_by_quarter).
        generated_at: Timestamp shown in the header. Any object is accepted
            and rendered via ``str()``; defaults to the current local time
            formatted as 'YYYY-MM-DD HH:MM'.

    Returns:
        A complete, self-contained HTML document as a single string.
    """
    if generated_at is None:
        generated_at = datetime.now().strftime("%Y-%m-%d %H:%M")

    filtered = filter_by_quarter(abstracts, quarter)

    # KPI tiles: distinct values across the filtered abstracts.
    n_abstracts = len(filtered)
    n_drugs = len({d for a in filtered for d in (a.get("drugs") or [])})
    # Only trials that actually carry an id count; otherwise every trial with
    # a missing "nct_id" would collapse into one phantom None entry.
    n_trials = len({
        t["nct_id"]
        for a in filtered
        for t in (a.get("linked_trials") or [])
        if t.get("nct_id")
    })
    n_sponsors = len({s for a in filtered for s in (a.get("sponsors") or [])})

    bar = svgcharts.weight_loss_bar_chart(_bar_data(filtered))
    transitions, _phase_counts = _phase_transitions(filtered)
    sankey = svgcharts.phase_transition_diagram(transitions)

    # NOTE: only these five conferences get a per-conference table; abstracts
    # tagged with any other "_conference" value still count toward the KPIs
    # and summaries above but are not listed individually.
    conf_sections = []
    for conf_name in ["ADA 2026", "EASD 2026", "Obesity Week 2026", "ENDO 2026", "ECO 2026"]:
        s = _section_for_conference(filtered, conf_name)
        if s:
            conf_sections.append(s)

    quarter_label = quarter or "all quarters"

    html_parts = [
        '<!DOCTYPE html>',
        '<html lang="en"><head>',
        '<meta charset="UTF-8">',
        f'<title>ObesityPharmaWatch — {escape(quarter_label)}</title>',
        f'<style>{CSS}</style>',
        '</head><body>',
        '<header>',
        '<h1>ObesityPharmaWatch — quarterly obesity pharma abstract roundup</h1>',
        # str() lets callers pass a datetime (or similar) instead of a string.
        f'<p>Quarter: {escape(quarter_label)} · generated {escape(str(generated_at))} · '
        '5 conferences (ADA · EASD · Obesity Week · ENDO · ECO) · synthetic mock data only</p>',
        '</header>',
        '<main>',
        '<section><div class="kpis">',
        f'<div class="kpi"><div class="v">{n_abstracts}</div><div class="l">Abstracts</div></div>',
        f'<div class="kpi"><div class="v">{n_drugs}</div><div class="l">Distinct drugs</div></div>',
        f'<div class="kpi"><div class="v">{n_trials}</div><div class="l">Linked CTG trials</div></div>',
        f'<div class="kpi"><div class="v">{n_sponsors}</div><div class="l">Sponsors</div></div>',
        '</div></section>',

        '<section>',
        '<h2>Drug × phase × weight-loss summary (top 12)</h2>',
        f'<div class="svg-wrap">{bar}</div>',
        '</section>',

        '<section>',
        '<h2>Pipeline phase transitions</h2>',
        f'<div class="svg-wrap">{sankey}</div>',
        '</section>',

        '<section>',
        '<h2>Drug-level summary table</h2>',
        _drug_summary_table(filtered),
        '</section>',

        '<section>',
        '<h2>Target-mechanism breakdown</h2>',
        _target_breakdown(filtered),
        '</section>',

        '<section>',
        '<h2>Per-conference abstracts</h2>',
        "".join(conf_sections),
        '</section>',

        '<section>',
        '<h2>Disclaimer</h2>',
        f'<p class="disclaimer">{escape(DISCLAIMER)}</p>',
        '</section>',

        '</main>',
        '<footer>ObesityPharmaWatch v0.1.0 · 2026 metabolic daily idea pipeline · '
        'mock data only · not for clinical use</footer>',
        '</body></html>',
    ]
    return "".join(html_parts)
