"""Inline SVG chart generator (no external CDN, no plotly)."""

from html import escape


def _svg_open(width, height, extra=""):
    return (
        f'<svg xmlns="http://www.w3.org/2000/svg" '
        f'viewBox="0 0 {width} {height}" '
        f'width="{width}" height="{height}" '
        f'role="img" aria-label="ObesityPharmaWatch chart" {extra}>'
    )


def weight_loss_bar_chart(data, width=720, height=380):
    """Horizontal bar chart of mean weight loss % per drug, as inline SVG.

    Bars are coloured by trial phase. The legend lists exactly the phases that
    appear in ``data``, drawn in the same colours as the bars. Any phase
    without a dedicated colour is grouped under "Other".

    Args:
        data: list of (drug_label, percent_loss, phase_label) tuples.
              percent_loss is a float (positive number representing %).
              A negative value (weight gain) is drawn as a zero-length bar
              labelled with a ``+`` sign, instead of producing an invalid
              negative SVG width.
        width: SVG canvas width in px.
        height: SVG canvas height in px.

    Returns:
        str: complete ``<svg>...</svg>`` markup, or a "No data" placeholder SVG
        when ``data`` is empty.
    """
    if not data:
        return f'{_svg_open(width, height)}<text x="20" y="40" font-family="Arial" font-size="14" fill="#666">No data</text></svg>'

    margin_left = 220
    margin_top = 40
    margin_right = 60
    margin_bottom = 50
    plot_w = width - margin_left - margin_right
    plot_h = height - margin_top - margin_bottom
    n = len(data)
    row_h = plot_h / n
    bar_h = max(10, row_h - 6)

    max_val = max(d[1] for d in data)
    if max_val <= 0:
        max_val = 1.0
    # Next multiple of 5 strictly above max_val (e.g. 20 -> 25). This also
    # leaves headroom for the value label printed after the longest bar.
    axis_max = (int(max_val / 5) + 1) * 5

    parts = [_svg_open(width, height)]
    parts.append(
        f'<text x="{width//2}" y="22" text-anchor="middle" '
        f'font-family="Arial" font-size="15" font-weight="600" fill="#1a1a1a">'
        f'Mean placebo-adjusted weight loss (%) by drug — synthetic mock data</text>'
    )

    # X-axis grid + labels (values are losses, hence the leading minus sign).
    for i in range(0, axis_max + 1, 5):
        x = margin_left + (i / axis_max) * plot_w
        parts.append(
            f'<line x1="{x:.1f}" y1="{margin_top}" x2="{x:.1f}" '
            f'y2="{margin_top + plot_h}" stroke="#e5e7eb" stroke-width="1"/>'
        )
        parts.append(
            f'<text x="{x:.1f}" y="{margin_top + plot_h + 18}" text-anchor="middle" '
            f'font-family="Arial" font-size="11" fill="#555">-{i}%</text>'
        )

    # Color per phase; the same mapping drives both bars and legend.
    phase_colors = {
        "Phase 1": "#cbd5e1",
        "Phase 1b": "#94a3b8",
        "Phase 1/2": "#94a3b8",
        "Phase 2": "#60a5fa",
        "Phase 2a": "#3b82f6",
        "Phase 2b": "#2563eb",
        "Phase 2/3": "#1d4ed8",
        "Phase 3": "#1e40af",
        "Phase 4": "#0c4a6e",
    }
    fallback_color = "#64748b"

    for idx, (label, val, phase) in enumerate(data):
        y = margin_top + idx * row_h + 4
        # Clamp: a negative value (weight gain) would otherwise yield a
        # negative rect width, which is an error in SVG.
        bw = max(0.0, (val / axis_max) * plot_w)
        color = phase_colors.get(phase, fallback_color)
        # Losses are shown as "-X%"; a gain is shown as "+X%" instead of "--X%".
        value_text = f"-{val:.1f}%" if val >= 0 else f"+{-val:.1f}%"
        parts.append(
            f'<rect x="{margin_left}" y="{y:.1f}" width="{bw:.1f}" '
            f'height="{bar_h:.1f}" fill="{color}" rx="3"/>'
        )
        parts.append(
            f'<text x="{margin_left - 8}" y="{y + bar_h/2 + 4:.1f}" '
            f'text-anchor="end" font-family="Arial" font-size="12" fill="#1a1a1a">'
            f'{escape(str(label))}</text>'
        )
        parts.append(
            f'<text x="{margin_left + bw + 6:.1f}" y="{y + bar_h/2 + 4:.1f}" '
            f'font-family="Arial" font-size="11" fill="#1a1a1a">'
            f'{value_text} · {escape(phase or "?")}</text>'
        )

    # Legend: one swatch per phase actually plotted, in the bar colours and in
    # phase order; anything without a dedicated colour is grouped as "Other".
    present = {phase for _, _, phase in data}
    legend_items = [(name, col) for name, col in phase_colors.items() if name in present]
    if any(phase not in phase_colors for phase in present):
        legend_items.append(("Other", fallback_color))

    legend_y = height - 18
    lx = margin_left
    step = 110
    if lx + step * len(legend_items) > width:
        # Too many entries to fit right of the drug labels: start near the
        # left edge and tighten the spacing so the legend stays on canvas.
        lx = 20
        step = min(110, (width - lx) / len(legend_items))
    for name, col in legend_items:
        parts.append(f'<rect x="{lx:.1f}" y="{legend_y - 10}" width="12" height="12" fill="{col}" rx="2"/>')
        parts.append(
            f'<text x="{lx + 18:.1f}" y="{legend_y}" font-family="Arial" '
            f'font-size="11" fill="#444">{escape(name)}</text>'
        )
        lx += step

    parts.append("</svg>")
    return "".join(parts)


def phase_transition_diagram(transitions, width=720, height=320):
    """Sankey-like phase transition diagram, as inline SVG.

    Each known phase (see ``phases_order``) that appears in ``transitions`` is
    drawn as a node on a single horizontal row, in pipeline order.

    - Forward transitions (towards a later column) and same-phase transitions
      are drawn as arcs above the row.
    - Backward transitions are drawn as arcs below the row, so they do not
      cut across the nodes.
    - Stroke width scales with ``count`` relative to the largest count.
    - Longer arcs bow further out, so overlapping spans stay distinguishable.

    Args:
        transitions: list of dicts with keys: from_phase, to_phase, count, drugs(list).
            Only ``from_phase``, ``to_phase`` and ``count`` are read here.
            Transitions that reference a phase outside ``phases_order`` are
            skipped.
        width: SVG canvas width in px.
        height: SVG canvas height in px.

    Returns:
        str: complete ``<svg>...</svg>`` markup, or a placeholder SVG when no
        known phase is referenced.
    """
    phases_order = ["Phase 1", "Phase 1b", "Phase 1/2", "Phase 2",
                    "Phase 2a", "Phase 2b", "Phase 2/3", "Phase 3", "Phase 4"]
    # Limit to phases actually present (preserving pipeline order).
    present = {t["from_phase"] for t in transitions} | {t["to_phase"] for t in transitions}
    used = [p for p in phases_order if p in present]

    if not used:
        return f'{_svg_open(width, height)}<text x="20" y="40" font-family="Arial" font-size="14" fill="#666">No phase transitions in this quarter</text></svg>'

    margin_top = 40
    margin_bottom = 30
    margin_x = 60
    plot_w = width - 2 * margin_x
    if len(used) == 1:
        # A lone node is centred rather than pinned to the left margin.
        spacing = plot_w
        col_x = {used[0]: width / 2}
    else:
        spacing = plot_w / (len(used) - 1)
        col_x = {p: margin_x + i * spacing for i, p in enumerate(used)}
    # Shrink nodes when columns are close (e.g. 9 phases -> 75px spacing) so
    # neighbouring nodes never overlap; cap at the original 90px.
    node_w = max(40.0, min(90.0, spacing - 10))
    node_h = 28

    parts = [_svg_open(width, height)]
    parts.append(
        f'<text x="{width//2}" y="22" text-anchor="middle" '
        f'font-family="Arial" font-size="15" font-weight="600" fill="#1a1a1a">'
        f'Pipeline phase transitions in quarter — synthetic mock data</text>'
    )

    # Draw nodes (one row in middle).
    node_y = (height - margin_bottom + margin_top) / 2 - node_h / 2
    node_bottom = node_y + node_h
    for p in used:
        x = col_x[p] - node_w / 2
        parts.append(
            f'<rect x="{x:.1f}" y="{node_y:.1f}" width="{node_w:.1f}" height="{node_h}" '
            f'fill="#1e40af" rx="6"/>'
        )
        parts.append(
            f'<text x="{col_x[p]:.1f}" y="{node_y + node_h/2 + 4:.1f}" '
            f'text-anchor="middle" font-family="Arial" font-size="11" '
            f'fill="white">{escape(p)}</text>'
        )

    # "or 1" guards against division by zero when every count is 0.
    max_count = max((t["count"] for t in transitions), default=0) or 1
    half_w = node_w / 2
    # Vertical space available for arcs (+ their labels) above/below the row.
    room_above = max(25.0, node_y - margin_top - 12)
    room_below = max(25.0, height - margin_bottom - node_bottom - 12)
    for t in transitions:
        src, dst = t["from_phase"], t["to_phase"]
        if src not in col_x or dst not in col_x:
            continue
        if col_x[dst] >= col_x[src]:
            # Forward / same-phase: source's right edge -> target's left edge, arc above.
            x1 = col_x[src] + half_w
            x2 = col_x[dst] - half_w
            y_base = node_y
            direction = -1
            room = room_above
        else:
            # Backward: source's left edge -> target's right edge, arc below.
            x1 = col_x[src] - half_w
            x2 = col_x[dst] + half_w
            y_base = node_bottom
            direction = 1
            room = room_below
        lift = min(25 + 0.1 * abs(x2 - x1), room)
        y_ctrl = y_base + direction * lift
        cx = (x1 + x2) / 2
        sw = 1 + 4 * (t["count"] / max_count)
        parts.append(
            f'<path d="M {x1:.1f} {y_base:.1f} '
            f'C {cx:.1f} {y_ctrl:.1f}, {cx:.1f} {y_ctrl:.1f}, '
            f'{x2:.1f} {y_base:.1f}" '
            f'stroke="#3b82f6" stroke-width="{sw:.1f}" fill="none" opacity="0.7"/>'
        )
        # Label sits just outside the arc's control height (above or below).
        label_y = y_ctrl - 4 if direction < 0 else y_ctrl + 12
        parts.append(
            f'<text x="{cx:.1f}" y="{label_y:.1f}" '
            f'text-anchor="middle" font-family="Arial" font-size="10" fill="#1e3a8a">'
            f'n={escape(str(t["count"]))}</text>'
        )

    # Footnote
    parts.append(
        f'<text x="{margin_x}" y="{height - 8}" font-family="Arial" '
        f'font-size="10" fill="#777">Curve thickness proportional to abstract count; '
        f'all values synthetic.</text>'
    )
    parts.append("</svg>")
    return "".join(parts)
