#!/usr/bin/env python3
from __future__ import annotations

import argparse
import re
import sys
from dataclasses import dataclass
from typing import List, Tuple

import numpy as np


@dataclass
class AtomRecord:
    """One atom block (atom line, charge line, P line) read from an .mps file."""

    # 0-based position of this atom block among the blocks found in the file.
    idx: int
    # Element symbol as written at the start of the atom line (one or two letters).
    element: str
    # The three coordinate values from the atom line, in (x, y, z) order.
    pos: Tuple[float, float, float]
    # Integer that follows the "Rank" keyword on the atom line.
    rank: int
    # Value of the single-number line that follows the atom line.
    charge: float
    # Value from the "P <value>" line that follows the charge line.
    polarizability: float  # the "P <value>" line


# Text of one signed decimal number with an optional exponent,
# e.g. "+30.2999798", "-.5", "1e-3". Shared by all three line patterns below.
_FLOAT = r"[+-]?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?"

# Atom line: element symbol (1-2 letters), three numbers, then "Rank <int>".
# Example:
# "  C +30.2999798 +1.2703692 +5.5205464 Rank 0"
ATOM_RE = re.compile(
    r"^\s*([A-Za-z]{1,2})\s+"
    + rf"({_FLOAT})\s+" * 3
    + r"Rank\s+(\d+)\s*$"
)

# Charge line: a single number on its own line.
# Example:
# "    -0.0145563"
NUM_RE = re.compile(rf"^\s*({_FLOAT})\s*$")

# Polarizability line: the letter P followed by a single number.
# Example:
# "     P +1.3340000"
P_RE = re.compile(rf"^\s*P\s+({_FLOAT})\s*$")


def parse_mps_votca(path: str) -> List[AtomRecord]:
    """
    Read a VOTCA/XTP .mps file and return its atoms in file order.

    Each atom is expected to occupy three consecutive lines:
      1. atom line   (element, three coordinates, "Rank <n>")
      2. charge line (a single number)
      3. P line      ("P <value>")
    Lines that do not match the atom pattern are skipped while looking for
    the next atom line.

    Raises:
      ValueError: an atom line is not followed by two more lines, the lines
        after it are not a charge line and a P line, or the file contains no
        atom block at all.
      OSError: the file cannot be opened.
    """
    with open(path, "r", encoding="utf-8", errors="replace") as handle:
        rows = [raw.rstrip("\n") for raw in handle]

    n_rows = len(rows)
    atoms: List[AtomRecord] = []
    cursor = 0
    while cursor < n_rows:
        header = ATOM_RE.match(rows[cursor])
        if header is None:
            # Not the start of an atom block: move on to the next line.
            cursor += 1
            continue

        symbol, *coord_text, rank_text = header.groups()
        coords = tuple(float(text) for text in coord_text)
        rank = int(rank_text)

        # The charge and P lines must both exist after the atom line.
        if n_rows - cursor < 3:
            raise ValueError(f"{path}: truncated atom block at line {cursor+1}")

        charge_row = rows[cursor + 1]
        charge_hit = NUM_RE.match(charge_row)
        if charge_hit is None:
            raise ValueError(f"{path}: expected charge line after atom at line {cursor+1}, got: {charge_row!r}")
        charge = float(charge_hit.group(1))

        pol_row = rows[cursor + 2]
        pol_hit = P_RE.match(pol_row)
        if pol_hit is None:
            raise ValueError(f"{path}: expected 'P <value>' line after charge at line {cursor+2}, got: {pol_row!r}")
        pol = float(pol_hit.group(1))

        # The running count of stored atoms doubles as the 0-based site index.
        atoms.append(
            AtomRecord(
                idx=len(atoms),
                element=symbol,
                pos=coords,
                rank=rank,
                charge=charge,
                polarizability=pol,
            )
        )
        cursor += 3  # jump past the atom/charge/P triplet

    if len(atoms) == 0:
        raise ValueError(f"{path}: no atom blocks found (did not match expected VOTCA/XTP .mps format)")
    return atoms


def isclose(a: float, b: float, rtol: float, atol: float) -> bool:
    """Tell whether two scalars agree within the given tolerances.

    Delegates to ``numpy.isclose`` with ``equal_nan=True`` and converts the
    result to a plain Python ``bool``.
    """
    verdict = np.isclose(a, b, rtol=rtol, atol=atol, equal_nan=True)
    return bool(verdict)


def _pick_charge_sign(pairs, rtol: float, atol: float) -> float:
    """Choose the global sign applied to B's charges before comparing.

    ``pairs`` is a list of ``(a, b)`` records. Returns ``-1.0`` only when
    negating every charge of B gives strictly fewer mismatches against A than
    leaving them unchanged; a tie keeps ``+1.0``.
    """
    a_ch = np.array([a.charge for a, _ in pairs], dtype=float)
    b_ch = np.array([b.charge for _, b in pairs], dtype=float)

    ok_same = np.isclose(a_ch, b_ch, rtol=rtol, atol=atol, equal_nan=True)
    ok_flip = np.isclose(a_ch, -b_ch, rtol=rtol, atol=atol, equal_nan=True)

    mism_same = int(np.count_nonzero(~ok_same))
    mism_flip = int(np.count_nonzero(~ok_flip))
    return -1.0 if mism_flip < mism_same else 1.0


def _collect_issues(
    A: "List[AtomRecord]",
    B: "List[AtomRecord]",
    *,
    rtol: float,
    atol: float,
    pos_atol: float,
    allow_charge_sign: bool,
) -> List[str]:
    """Return one message per difference between the two record lists.

    A differing atom count is reported first; the per-atom checks then cover
    only the sites present in both lists. Nothing is truncated here, so the
    length of the result is the true number of differences.
    """
    issues: List[str] = []

    if len(A) != len(B):
        issues.append(f"number of atoms differs: A={len(A)} vs B={len(B)}")

    # zip() stops at the shorter list, i.e. at the sites both files have.
    pairs = list(zip(A, B))

    sign = 1.0
    if allow_charge_sign and pairs:
        # Only the sign is chosen here; charges that still disagree after
        # applying it are reported by the per-atom check below.
        sign = _pick_charge_sign(pairs, rtol, atol)

    for i, (a, b) in enumerate(pairs):
        if a.element != b.element:
            issues.append(f"[site {i}] element differs: A={a.element} vs B={b.element}")

        if a.rank != b.rank:
            issues.append(f"[site {i} {a.element}] rank differs: A={a.rank} vs B={b.rank}")

        # positions (own absolute tolerance, shared relative tolerance)
        for c, av, bv in zip(("x", "y", "z"), a.pos, b.pos):
            if not isclose(av, bv, rtol=rtol, atol=pos_atol):
                issues.append(
                    f"[site {i} {a.element}] pos.{c} differs: A={av:.16g} vs B={bv:.16g} (rtol={rtol}, atol={pos_atol})"
                )

        # polarizability P
        if not isclose(a.polarizability, b.polarizability, rtol=rtol, atol=atol):
            issues.append(
                f"[site {i} {a.element}] P differs: A={a.polarizability:.16g} vs B={b.polarizability:.16g} "
                f"(rtol={rtol}, atol={atol})"
            )

        # charge (optionally sign invariant)
        b_charge = sign * b.charge if allow_charge_sign else b.charge
        if not isclose(a.charge, b_charge, rtol=rtol, atol=atol):
            note = " (after global sign flip)" if (allow_charge_sign and sign < 0) else ""
            issues.append(
                f"[site {i} {a.element}] charge differs{note}: A={a.charge:.16g} vs B={b_charge:.16g} "
                f"(rtol={rtol}, atol={atol})"
            )

    return issues


def main(argv: "List[str] | None" = None) -> int:
    """Compare two .mps files and report every difference found.

    Args:
      argv: command-line arguments without the program name. ``None`` (the
        default) makes argparse read ``sys.argv[1:]``, as before.

    Returns:
      0 when the files match within the tolerances, 1 when differences were
      found, 2 when either file could not be read or parsed.

    The FAIL header reports the real number of differences. ``--max-issues``
    only limits how many of them are printed; a trailing "... more issues not
    shown" line appears only when some were actually left out. A value of
    zero or less prints no individual issue.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("file_a")
    ap.add_argument("file_b")
    ap.add_argument("--rtol", type=float, default=1e-8)
    ap.add_argument("--atol", type=float, default=1e-12)
    ap.add_argument("--pos-atol", type=float, default=None,
                    help="Optional absolute tolerance specifically for positions (defaults to --atol).")
    ap.add_argument("--allow-charge-sign", action="store_true",
                    help="Allow a consistent global sign flip for all charges: accept A≈B or A≈-B.")
    ap.add_argument("--max-issues", type=int, default=200,
                    help="Maximum number of issues to print (all issues are still counted).")
    args = ap.parse_args(argv)

    pos_atol = args.pos_atol if args.pos_atol is not None else args.atol

    # Top-level boundary: any read/parse failure becomes exit status 2.
    try:
        A = parse_mps_votca(args.file_a)
        B = parse_mps_votca(args.file_b)
    except Exception as e:
        print(f"[compare_mps] ERROR: {e}", file=sys.stderr)
        return 2

    issues = _collect_issues(
        A,
        B,
        rtol=args.rtol,
        atol=args.atol,
        pos_atol=pos_atol,
        allow_charge_sign=args.allow_charge_sign,
    )

    if issues:
        print(f"[compare_mps] FAIL: {len(issues)} issue(s) found")
        # Clamp so a negative limit behaves like zero instead of slicing from the end.
        shown = issues[:max(args.max_issues, 0)]
        for msg in shown:
            print(" - " + msg)
        if len(shown) < len(issues):
            print(f" - ... more issues not shown (max_issues={args.max_issues})")
        return 1

    print("[compare_mps] OK: files match (within tolerances" +
          (", global charge sign allowed" if args.allow_charge_sign else "") + ")")
    return 0


if __name__ == "__main__":
    # Run the CLI and pass main()'s return value on as the process exit status.
    sys.exit(main())
