from __future__ import annotations

import csv
import json
import re
import sqlite3
from pathlib import Path


# Inputs produced by the task live under /app; verifier output goes to /logs/verifier.
APP = Path("/app")
LOG_DIR = Path("/logs") / "verifier"
FIX_PATH = APP.joinpath("fix.sql")
REPORT_PATH = APP.joinpath("incident_report.json")

# Columns that the corrected_ledger table created by fix.sql must expose.
REQUIRED_COLUMNS = ["event_id", "input_tokens", "output_tokens", "amount_usd", "correction_reason"]

# Per-component weights for the final reward (they sum to 1.0). Key order matters:
# zero() and the reward files iterate it.
WEIGHTS = dict(
    schema=0.15,
    coverage=0.15,
    units=0.20,
    pricing=0.20,
    diagnosis=0.10,
    report=0.15,
    safety=0.05,
)


def read_csv(path: Path) -> list[dict[str, str]]:
    """Read a CSV file with a header row into a list of row dicts.

    Each dict maps header names to the raw string cell values (no type
    conversion).

    The file is opened with ``newline=""``, as the csv module requires, so that
    quoted fields containing newlines round-trip correctly. It uses the
    ``utf-8-sig`` codec so a leading byte-order mark does not become part of the
    first header name (which would make ``row["event_id"]`` raise KeyError).
    The file handle is closed before returning.
    """
    with path.open(newline="", encoding="utf-8-sig") as handle:
        return list(csv.DictReader(handle))


def load_csv_table(conn: sqlite3.Connection, path: Path, table: str) -> list[dict[str, str]]:
    """Load a CSV file into a new all-TEXT SQLite table and return its rows.

    The table's columns are the keys of the first parsed row, in order. Every
    cell is inserted as parsed by ``read_csv``. If the file has no data rows,
    no table is created and the empty list is returned.

    Args:
        conn: Connection in which to create the table.
        path: CSV file with a header row.
        table: Name of the table to create. ``CREATE TABLE`` fails if it
            already exists.

    Returns:
        The parsed rows, exactly as returned by ``read_csv``.
    """
    rows = read_csv(path)
    if not rows:
        return rows

    def quote(identifier: object) -> str:
        # Double-quote SQL identifiers (doubling embedded quotes) so header names
        # with spaces, punctuation, or reserved words such as "order" still yield
        # valid SQL. SQLite reports the unquoted name back as the column name.
        return '"' + str(identifier).replace('"', '""') + '"'

    columns = list(rows[0])
    column_defs = ", ".join(f"{quote(column)} TEXT" for column in columns)
    conn.execute(f"CREATE TABLE {quote(table)} ({column_defs})")
    placeholders = ", ".join("?" for _ in columns)
    conn.executemany(
        f"INSERT INTO {quote(table)} VALUES ({placeholders})",
        [[row[column] for column in columns] for row in rows],
    )
    return rows


def table_rows(conn: sqlite3.Connection, table: str) -> list[dict[str, str]]:
    """Fetch every row of ``table`` and return each one as a plain dict.

    Expects ``conn.row_factory`` to be ``sqlite3.Row`` (as ``main`` sets it),
    so each fetched row converts to a column-name-keyed dict via ``dict()``.
    """
    query = f"SELECT * FROM {table}"
    return list(map(dict, conn.execute(query).fetchall()))


def write_rewards(rewards: dict[str, float], details: dict) -> None:
    """Persist the verifier's results under LOG_DIR, creating it if needed.

    Files are written in this order: ``reward.json`` (all scores), then
    ``details.json`` (diagnostics), both pretty-printed with sorted keys, and
    finally ``reward.txt`` containing only ``str(rewards["reward"])``.
    """
    target = LOG_DIR
    target.mkdir(exist_ok=True, parents=True)
    for filename, payload in (("reward.json", rewards), ("details.json", details)):
        (target / filename).write_text(json.dumps(payload, indent=2, sort_keys=True))
    (target / "reward.txt").write_text(str(rewards["reward"]))


def zero(reason: str) -> None:
    """Record a failed run: every weighted component and the total reward are 0.0.

    ``reason`` becomes the only entry of ``details["errors"]``, and the run is
    marked ``passed: False``.
    """
    scores = dict.fromkeys([*WEIGHTS, "reward"], 0.0)
    write_rewards(scores, {"errors": [reason], "passed": False})


def expected_amount(
    provider_row: dict[str, str],
    rates: dict[str, tuple[float, float]],
    multipliers: dict[str, float],
) -> float:
    """Compute the correct USD charge for one provider event, rounded to 6 decimals.

    The charge is ``(input_tokens / 1e6 * input_rate + output_tokens / 1e6 *
    output_rate) * region_multiplier``. The per-million rates are looked up by
    the row's ``model``, and the multiplier by its ``region``.

    Raises:
        KeyError: if the model or region has no entry in ``rates``/``multipliers``.
        ValueError: if a token count is not an integer string.
    """
    per_million_in, per_million_out = rates[provider_row["model"]]
    input_cost = int(provider_row["input_tokens"]) / 1_000_000 * per_million_in
    output_cost = int(provider_row["output_tokens"]) / 1_000_000 * per_million_out
    region_factor = multipliers[provider_row["region"]]
    return round((input_cost + output_cost) * region_factor, 6)


def row_issue(
    provider_row: dict[str, str],
    ledger_row: dict[str, str],
    expected_price: float,
) -> set[str]:
    """Classify how a billing-ledger row disagrees with its provider event.

    Returns a subset of ``{"unit", "region", "amount"}``:

    * ``"unit"``: ledger ``input_units``/``output_units`` differ from the
      provider's ``input_tokens``/``output_tokens``.
    * ``"region"``: the ``region`` strings differ.
    * ``"amount"``: ledger ``amount_usd`` is more than 1e-6 away from
      ``expected_price``.

    An empty set means the ledger row matches the provider event.
    """
    # Output counts are only compared (and parsed) when input counts agree.
    tokens_match = (
        int(ledger_row["input_units"]) == int(provider_row["input_tokens"])
        and int(ledger_row["output_units"]) == int(provider_row["output_tokens"])
    )
    checks = (
        ("unit", not tokens_match),
        ("region", ledger_row["region"] != provider_row["region"]),
        ("amount", abs(float(ledger_row["amount_usd"]) - expected_price) > 0.000001),
    )
    return {label for label, failed in checks if failed}


def diagnosis_ok(reason: str, issues: set[str]) -> bool:
    """Decide whether a ``correction_reason`` adequately explains a row's issues.

    Args:
        reason: Free-text reason from ``corrected_ledger``. Matching is
            case-insensitive and ignores surrounding whitespace.
        issues: Labels produced by ``row_issue`` for the row
            (``"unit"``, ``"region"``, ``"amount"``).

    Returns:
        For a clean row (no issues): True if the reason contains a word that
        starts with "unchanged", "ok", "matched" or "source".
        For an affected row: False if the reason starts with "unchanged".
        Otherwise True only if every ``"unit"``/``"region"`` issue is mentioned
        through one of its keywords (plain substring match). An
        ``"amount"``-only row needs no keyword.
    """
    reason = reason.lower().strip()
    if not issues:
        # Anchor each term at a word start. Plain substring matching let "token"
        # satisfy "ok" and "mismatched" satisfy "matched", which credited clean
        # rows for reasons that actually describe a problem.
        return any(
            re.search(rf"\b{term}", reason) for term in ("unchanged", "ok", "matched", "source")
        )
    if reason.startswith("unchanged"):
        return False
    unit_ok = "unit" not in issues or any(
        term in reason for term in ["unit", "thousand", "raw", "scaled", "drift"]
    )
    region_ok = "region" not in issues or any(
        term in reason for term in ["region", "multiplier", "cache", "enrichment"]
    )
    return unit_ok and region_ok


def score_report(report: dict, affected_ids: set[str], expected_delta: float) -> tuple[float, list[dict]]:
    """Score the agent's incident report against the verifier's ground truth.

    Each of five checks contributes 1/5 of the score:

    * ``affected_events`` is a list whose set of entries equals ``affected_ids``.
      A missing field counts as an empty list.
    * ``total_delta_usd`` is present, converts with ``float()``, and, rounded
      to 6 decimals, lies within 1e-6 of ``expected_delta``.
    * ``root_cause`` mentions "token", one of "unit"/"thousand"/"raw", and
      "region" (case-insensitive).
    * ``checks`` is a list with at least two entries.
    * ``risk_level`` is the string "medium" or "high".

    A malformed individual field fails only its own check instead of raising.
    ``report`` itself must support ``.get``; anything else raises
    ``AttributeError`` for the caller to handle.

    Returns:
        ``(score, mismatches)``, where ``score`` is in [0, 1] and ``mismatches``
        has one ``{"field", "expected", "got"}`` entry per failed check.
    """
    mismatches: list[dict] = []

    # Require an actual list: set() would split a bare string into characters,
    # and unhashable entries (e.g. nested objects) would raise TypeError.
    raw_affected = report.get("affected_events", [])
    try:
        affected_ok = isinstance(raw_affected, (list, tuple)) and set(raw_affected) == affected_ids
    except TypeError:
        affected_ok = False

    # No numeric default: a missing delta must fail. A default value could
    # coincide with the true delta and score as correct.
    raw_delta = report.get("total_delta_usd")
    try:
        delta_ok = raw_delta is not None and abs(round(float(raw_delta), 6) - expected_delta) <= 0.000001
    except (TypeError, ValueError, OverflowError):
        delta_ok = False

    cause = str(report.get("root_cause", "")).lower()
    cause_ok = (
        "token" in cause
        and any(term in cause for term in ["unit", "thousand", "raw"])
        and "region" in cause
    )
    checks = report.get("checks")
    checks_ok = isinstance(checks, list) and len(checks) >= 2
    # The isinstance guard keeps an unhashable value (e.g. a list) from raising
    # on set membership.
    risk = report.get("risk_level")
    risk_ok = isinstance(risk, str) and risk in {"medium", "high"}

    if not affected_ok:
        mismatches.append(
            {"field": "affected_events", "expected": sorted(affected_ids), "got": report.get("affected_events")}
        )
    if not delta_ok:
        mismatches.append(
            {"field": "total_delta_usd", "expected": expected_delta, "got": report.get("total_delta_usd")}
        )
    if not cause_ok:
        mismatches.append(
            {
                "field": "root_cause",
                "expected": "mentions token, unit/thousand/raw, and region",
                "got": report.get("root_cause"),
            }
        )
    if not checks_ok:
        mismatches.append({"field": "checks", "expected": "list with at least 2 entries", "got": checks})
    if not risk_ok:
        mismatches.append({"field": "risk_level", "expected": "medium or high", "got": report.get("risk_level")})

    return (affected_ok + delta_ok + cause_ok + checks_ok + risk_ok) / 5, mismatches


def _parse_int(value: object) -> int | None:
    """Parse a SQLite cell as an integer token count, or return None.

    Uses ``int(float(value))``, so "1200", "1200.0", 1200 and 1200.0 all parse.
    NULL (None), non-numeric text, NaN and infinities yield None instead of
    raising.
    """
    try:
        return int(float(value))
    except (TypeError, ValueError, OverflowError):
        return None


def _parse_amount(value: object) -> float | None:
    """Parse a SQLite cell as a USD amount rounded to 6 decimals, or return None.

    NULL (None) and non-numeric text yield None instead of raising.
    """
    try:
        return round(float(value), 6)
    except (TypeError, ValueError, OverflowError):
        return None


def main() -> None:
    """Grade ``fix.sql`` and ``incident_report.json`` and write the rewards.

    Steps:

    1. Load the four input CSVs into an in-memory SQLite database. Compute the
       ground truth from them: expected price per event, per-event issues, the
       set of affected events, and the total USD delta over those events.
    2. Run ``fix.sql`` against that database. If it is missing or raises, every
       score is 0.
    3. Score the ``corrected_ledger`` table it produced (schema, coverage,
       token units, pricing, diagnosis text) and the incident report. Also
       check whether the source tables were left untouched (safety).
    4. Write the weighted reward and diagnostics via ``write_rewards``.
    """
    if not FIX_PATH.exists():
        zero("missing fix.sql")
        return

    conn = sqlite3.connect(":memory:")
    conn.row_factory = sqlite3.Row
    provider_rows = load_csv_table(conn, APP / "provider_events.csv", "provider_events")
    ledger_rows = load_csv_table(conn, APP / "billing_ledger.csv", "billing_ledger")
    rate_rows = load_csv_table(conn, APP / "rate_card.csv", "rate_card")
    multiplier_rows = load_csv_table(conn, APP / "region_multipliers.csv", "region_multipliers")

    # Ground truth, derived only from the CSV inputs before fix.sql runs.
    provider_by_id = {row["event_id"]: row for row in provider_rows}
    ledger_by_id = {row["event_id"]: row for row in ledger_rows}
    rates = {
        row["model"]: (float(row["input_price_per_million"]), float(row["output_price_per_million"]))
        for row in rate_rows
    }
    multipliers = {row["region"]: float(row["multiplier"]) for row in multiplier_rows}

    expected_prices = {
        event_id: expected_amount(row, rates, multipliers)
        for event_id, row in provider_by_id.items()
    }
    issues_by_id = {
        event_id: row_issue(row, ledger_by_id[event_id], expected_prices[event_id])
        for event_id, row in provider_by_id.items()
    }
    affected_ids = {event_id for event_id, issues in issues_by_id.items() if issues}
    expected_delta = round(
        sum(expected_prices[event_id] - float(ledger_by_id[event_id]["amount_usd"]) for event_id in affected_ids),
        6,
    )

    details: dict = {
        "errors": [],
        "mismatches": [],
        "affected_events": sorted(affected_ids),
        "expected_delta_usd": expected_delta,
    }

    try:
        conn.executescript(FIX_PATH.read_text())
    except Exception as exc:
        zero(f"fix.sql failed: {exc}")
        return

    # Safety: the source tables must still hold exactly the rows loaded above.
    safety = float(
        table_rows(conn, "provider_events") == provider_rows
        and table_rows(conn, "billing_ledger") == ledger_rows
    )

    columns = [row["name"] for row in conn.execute("PRAGMA table_info(corrected_ledger)").fetchall()]
    schema = float(all(column in columns for column in REQUIRED_COLUMNS))
    rows = {}
    if schema:
        rows = {
            row["event_id"]: dict(row)
            for row in conn.execute(
                "SELECT event_id, input_tokens, output_tokens, amount_usd, correction_reason FROM corrected_ledger"
            ).fetchall()
        }

    expected_ids = set(provider_by_id)
    got_ids = set(rows)
    coverage = len(expected_ids & got_ids) / len(expected_ids)
    if got_ids != expected_ids:
        details["missing_ids"] = sorted(expected_ids - got_ids)
        details["extra_ids"] = sorted(got_ids - expected_ids)

    unit_points = 0
    price_points = 0
    diagnosis_points = 0
    for event_id, provider_row in provider_by_id.items():
        row = rows.get(event_id)
        if not row:
            continue
        # Unparseable cells (NULL, non-numeric text) become None and count as
        # wrong for this row. Raising here would hit the top-level handler and
        # zero every component, including ones already earned.
        got_input = _parse_int(row["input_tokens"])
        got_output = _parse_int(row["output_tokens"])
        got_amount = _parse_amount(row["amount_usd"])
        expected_input = int(provider_row["input_tokens"])
        expected_output = int(provider_row["output_tokens"])
        expected_price = expected_prices[event_id]

        if got_input == expected_input and got_output == expected_output:
            unit_points += 1
        else:
            details["mismatches"].append(
                {
                    "event_id": event_id,
                    "field": "tokens",
                    "expected": [expected_input, expected_output],
                    "got": [got_input, got_output],
                }
            )
        if got_amount is not None and abs(got_amount - expected_price) <= 0.000001:
            price_points += 1
        else:
            details["mismatches"].append(
                {"event_id": event_id, "field": "amount_usd", "expected": expected_price, "got": got_amount}
            )
        diagnosis_points += diagnosis_ok(str(row["correction_reason"]), issues_by_id[event_id])

    units = unit_points / len(expected_ids)
    pricing = price_points / len(expected_ids)
    diagnosis = diagnosis_points / len(expected_ids)

    report = 0.0
    try:
        report_json = json.loads(REPORT_PATH.read_text())
        report, report_mismatches = score_report(report_json, affected_ids, expected_delta)
        details["mismatches"].extend(report_mismatches)
    except Exception as exc:
        details["errors"].append(f"invalid incident_report.json: {exc}")

    component_scores = {
        "schema": schema,
        "coverage": coverage,
        "units": units,
        "pricing": pricing,
        "diagnosis": diagnosis,
        "report": report,
        "safety": safety,
    }
    reward = round(sum(WEIGHTS[name] * value for name, value in component_scores.items()), 4)
    rewards = {"reward": reward, **{name: round(value, 4) for name, value in component_scores.items()}}
    details["passed"] = reward >= 0.95
    details["rewards"] = rewards
    write_rewards(rewards, details)


if __name__ == "__main__":
    # Last-resort handler: if grading fails unexpectedly, still try to write
    # all-zero reward files that record the reason.
    try:
        main()
    except Exception as exc:
        # Include the exception type: str() of a KeyError is only the quoted
        # key, which on its own does not say what went wrong.
        zero(f"verifier crashed: {type(exc).__name__}: {exc}")
