#!/usr/bin/env python3
"""
compare_xml_numeric.py — semantic XML compare with:
- numeric tolerances (rtol/atol) for text and attributes
- ignore tags at any depth (skip subtree)
- per-tag tolerance overrides (e.g. qp_energy needs looser tolerance)
- per-tag "allow global sign flip" for numeric vectors in text (e.g. Trdipole)

Examples:
  python compare_xml_numeric.py ref.xml out.xml \
    --ignore-tag Compute_Time --ignore-tag host --ignore-tag time \
    --tag-tol qp_energy:1e-7:1e-6 \
    --signflip-tag Trdipole

Format for --tag-tol:
  TAG:RTOL:ATOL
"""

from __future__ import annotations

import argparse
import math
import re
import sys
from typing import Optional, Set, List, Dict, Tuple

from lxml import etree

# Pattern for a single plain decimal literal: an optional sign, then either
# digits with an optional fractional part or a bare ".digits" fraction, then
# an optional e/E exponent. It is anchored with ^ and $, so surrounding text
# makes the match fail; spellings such as "nan" or "inf" do not match.
NUM_RE = re.compile(r"""
    ^[+-]?(
      (\d+(\.\d*)?)|(\.\d+)
    )([eE][+-]?\d+)?$
""", re.VERBOSE)


def norm_text(s: Optional[str]) -> str:
    """Return *s* without leading/trailing whitespace; ``None`` becomes ``""``."""
    if s is None:
        return ""
    return s.strip()


def is_number(s: str) -> bool:
    """Tell whether *s*, ignoring surrounding whitespace, matches ``NUM_RE``."""
    candidate = s.strip()
    return NUM_RE.match(candidate) is not None


def to_float(s: str) -> float:
    """Convert *s* to ``float`` after trimming surrounding whitespace.

    Raises ``ValueError`` (from ``float``) when the text is not a valid number.
    """
    trimmed = s.strip()
    return float(trimmed)


def num_equal(a: float, b: float, atol: float, rtol: float) -> bool:
    """Return True when *a* and *b* agree within the given tolerances.

    Thin wrapper over ``math.isclose``; note the argument order here is
    absolute tolerance first, relative tolerance second.
    """
    within_tolerance = math.isclose(a, b, rel_tol=rtol, abs_tol=atol)
    return within_tolerance


def parse_tag_tol(items: List[str]) -> Dict[str, Tuple[float, float]]:
    """Parse repeated ``--tag-tol`` values into a per-tag tolerance table.

    Each item has the form ``TAG:RTOL:ATOL``. The item is split from the
    right, so a tag that itself contains colons (for example a namespaced
    tag written as ``{http://example.org/ns}energy``) is accepted.

    Returns:
        Mapping of tag name to ``(rtol, atol)``. If a tag is given more than
        once, the last occurrence wins.

    Raises:
        ValueError: if an item does not have three fields, has an empty tag,
            has a non-numeric tolerance, or has a negative/NaN tolerance.
    """
    out: Dict[str, Tuple[float, float]] = {}
    for it in items:
        # Split from the right: only the last two fields are numbers, the
        # tag may legitimately contain ":" characters.
        parts = it.rsplit(":", 2)
        if len(parts) != 3 or not parts[0]:
            raise ValueError(f"--tag-tol expects TAG:RTOL:ATOL, got {it!r}")
        tag, rtol_s, atol_s = parts
        try:
            rtol = float(rtol_s)
            atol = float(atol_s)
        except ValueError:
            raise ValueError(
                f"--tag-tol expects numeric RTOL and ATOL, got {it!r}"
            ) from None
        # math.isclose raises on negative tolerances; fail here, at option
        # parsing time, rather than in the middle of a comparison. The
        # ">= 0" form also rejects NaN.
        if not (rtol >= 0 and atol >= 0):
            raise ValueError(
                f"--tag-tol tolerances must be non-negative, got {it!r}"
            )
        out[tag] = (rtol, atol)
    return out


def pick_tols(tag: str, default_rtol: float, default_atol: float,
              tag_tols: Dict[str, Tuple[float, float]]) -> Tuple[float, float]:
    """Return the ``(rtol, atol)`` pair to use for *tag*.

    A per-tag entry in *tag_tols* takes precedence; otherwise the defaults
    are returned in the same ``(rtol, atol)`` order.
    """
    if tag in tag_tols:
        return tag_tols[tag]
    return (default_rtol, default_atol)


def parse_float_vector(s: str) -> Optional[List[float]]:
    """
    Split *s* on whitespace and convert every token to a float.

    Blank (or None) input yields an empty list. If any token fails the
    is_number check, the whole parse is rejected and None is returned.
    """
    tokens = norm_text(s).split()
    if any(not is_number(tok) for tok in tokens):
        return None
    return [float(tok) for tok in tokens]


def vec_equal_allow_signflip(a: List[float], b: List[float], atol: float, rtol: float) -> bool:
    """Return True when *a* matches *b* or *-b* element-wise within tolerance.

    Vectors of different length never match. The sign flip is global: either
    every component matches as-is, or every component matches negated.
    """
    if len(a) != len(b):
        return False

    def close(x: float, y: float) -> bool:
        return math.isclose(x, y, rel_tol=rtol, abs_tol=atol)

    # First try a ~ b, then fall back to a ~ -b.
    if all(close(x, y) for x, y in zip(a, b)):
        return True
    return all(close(x, -y) for x, y in zip(a, b))


def compare_values(path: str, a: Optional[str], b: Optional[str],
                   rtol: float, atol: float, diffs: List[str],
                   allow_signflip_vector: bool = False) -> None:
    """Compare two text values and append a message to *diffs* on mismatch.

    Both values are whitespace-normalised first; identical strings are
    accepted immediately. Otherwise:

    - with *allow_signflip_vector*, if both sides parse as numeric vectors
      they are compared element-wise, accepting an overall sign flip;
    - if both sides are single numbers they are compared with *atol*/*rtol*;
    - anything else is reported as a text mismatch.

    Nothing is returned; *diffs* is mutated in place, and *path* prefixes
    every message.
    """
    left, right = norm_text(a), norm_text(b)
    if left == right:
        return

    if allow_signflip_vector:
        vec_left = parse_float_vector(left)
        vec_right = parse_float_vector(right)
        both_vectors = vec_left is not None and vec_right is not None
        if both_vectors:
            if not vec_equal_allow_signflip(vec_left, vec_right, atol=atol, rtol=rtol):
                diffs.append(
                    f"{path}: numeric-vector mismatch (± allowed) {vec_left} vs {vec_right} (atol={atol}, rtol={rtol})"
                )
            return
        # At least one side is not a pure numeric vector: use the ordinary
        # scalar/text rules below.

    if not (is_number(left) and is_number(right)):
        diffs.append(f"{path}: text mismatch {left!r} vs {right!r}")
        return

    num_left, num_right = to_float(left), to_float(right)
    if not num_equal(num_left, num_right, atol, rtol):
        diffs.append(f"{path}: numeric mismatch {num_left} vs {num_right} (atol={atol}, rtol={rtol})")


def compare_elements(
    ea: etree._Element,
    eb: etree._Element,
    path: str,
    default_atol: float,
    default_rtol: float,
    diffs: List[str],
    ignore_tags: Set[str],
    tag_tols: Dict[str, Tuple[float, float]],
    signflip_tags: Set[str],
) -> None:
    """Recursively compare two elements, appending findings to *diffs*.

    Order of checks: tag, attributes, text, children (pairwise, in document
    order), then tail text. Tolerances come from *tag_tols* for this
    element's tag, falling back to the defaults. Elements whose tag is in
    *ignore_tags* are skipped together with their subtree. The sign-flip
    rule applies only to the text of tags listed in *signflip_tags*, never
    to attributes or tail text.

    Comparison of this element stops early when the tags differ or when the
    number of non-ignored children differs.
    """
    tag_a, tag_b = ea.tag, eb.tag

    # An ignored tag on either side silences the whole subtree.
    if tag_a in ignore_tags or tag_b in ignore_tags:
        return

    if tag_a != tag_b:
        diffs.append(f"{path}: tag differs {tag_a!r} vs {tag_b!r}")
        return

    rtol, atol = pick_tols(tag_a, default_rtol, default_atol, tag_tols)

    # Attributes are matched by name, so their order does not matter.
    names_a = set(ea.attrib)
    names_b = set(eb.attrib)
    diffs.extend(
        f"{path}: attribute {name!r} missing in B" for name in sorted(names_a - names_b)
    )
    diffs.extend(
        f"{path}: attribute {name!r} missing in A" for name in sorted(names_b - names_a)
    )
    for name in sorted(names_a & names_b):
        compare_values(
            f"{path}/@{name}", ea.attrib[name], eb.attrib[name],
            rtol, atol, diffs, allow_signflip_vector=False,
        )

    # Element text is the only place where the sign-flip rule may apply.
    compare_values(
        path, ea.text, eb.text, rtol, atol, diffs,
        allow_signflip_vector=tag_a in signflip_tags,
    )

    # Children are paired positionally after dropping ignored tags.
    kids_a = [child for child in ea if child.tag not in ignore_tags]
    kids_b = [child for child in eb if child.tag not in ignore_tags]
    if len(kids_a) != len(kids_b):
        diffs.append(f"{path}: number of (non-ignored) children differs {len(kids_a)} vs {len(kids_b)}")
        return

    for index, kid_a in enumerate(kids_a):
        compare_elements(
            ea=kid_a,
            eb=kids_b[index],
            path=f"{path}/{kid_a.tag}[{index}]",
            default_atol=default_atol,
            default_rtol=default_rtol,
            diffs=diffs,
            ignore_tags=ignore_tags,
            tag_tols=tag_tols,
            signflip_tags=signflip_tags,
        )

    # Tail text (what follows the closing tag) is usually just whitespace.
    compare_values(f"{path}#tail", ea.tail, eb.tail, rtol, atol, diffs, allow_signflip_vector=False)


def load_xml(path: str) -> etree._Element:
    """Parse the XML file at *path* and return its root element.

    The parser drops whitespace-only text nodes, leaves entity references
    unresolved and is created with network access disabled.
    """
    xml_parser = etree.XMLParser(
        no_network=True,
        resolve_entities=False,
        remove_blank_text=True,
    )
    tree = etree.parse(path, xml_parser)
    return tree.getroot()


def main() -> int:
    """Run the command-line comparison and return the process exit status.

    Returns:
        0 when no differences were recorded, 1 when differences were found
        (at most ``--max-report`` of them are printed), 2 when an option
        value is invalid or either file cannot be loaded.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("ref")
    ap.add_argument("out")
    ap.add_argument("--atol", type=float, default=1e-12)
    ap.add_argument("--rtol", type=float, default=1e-8)

    ap.add_argument("--ignore-tag", action="append", default=[],
                    help="Tag name to ignore completely at any depth (repeatable)")

    ap.add_argument("--tag-tol", action="append", default=[],
                    help="Per-tag tolerance override: TAG:RTOL:ATOL (repeatable), e.g. qp_energy:1e-7:1e-6")

    ap.add_argument("--signflip-tag", action="append", default=[],
                    help="For these tags, allow numeric vectors in text to match up to an overall sign flip (repeatable)")

    ap.add_argument("--max-report", type=int, default=50)
    args = ap.parse_args()

    # math.isclose raises ValueError for a negative tolerance, and a NaN
    # tolerance can never be satisfied. Reject both up front so the user gets
    # a clear message instead of a traceback midway through the comparison.
    # The ">= 0" form is False for NaN as well as for negative values.
    for opt in ("atol", "rtol"):
        value = getattr(args, opt)
        if not value >= 0:
            print(f"[compare_xml] bad --{opt}: must be a non-negative number, got {value!r}",
                  file=sys.stderr)
            return 2

    # A negative cap would make the report slice drop trailing entries and
    # the "more differences" count come out wrong.
    if args.max_report < 0:
        print(f"[compare_xml] bad --max-report: must be >= 0, got {args.max_report}",
              file=sys.stderr)
        return 2

    ignore_tags = set(args.ignore_tag)
    signflip_tags = set(args.signflip_tag)

    try:
        tag_tols = parse_tag_tol(args.tag_tol)
    except ValueError as e:
        print(f"[compare_xml] bad --tag-tol: {e}", file=sys.stderr)
        return 2

    # Deliberately broad: this is the top-level boundary, and any failure to
    # read or parse either file is reported the same way.
    try:
        ra = load_xml(args.ref)
        rb = load_xml(args.out)
    except Exception as e:
        print(f"[compare_xml] parse error: {e}", file=sys.stderr)
        return 2

    diffs: List[str] = []
    compare_elements(
        ra, rb, f"/{ra.tag}",
        default_atol=args.atol,
        default_rtol=args.rtol,
        diffs=diffs,
        ignore_tags=ignore_tags,
        tag_tols=tag_tols,
        signflip_tags=signflip_tags,
    )

    if diffs:
        print(f"[compare_xml] FAIL: {len(diffs)} differences found")
        for d in diffs[:args.max_report]:
            print(" -", d)
        if len(diffs) > args.max_report:
            print(f" ... {len(diffs) - args.max_report} more differences")
        return 1

    print("[compare_xml] OK (numeric tolerance + ignore tags + per-tag tolerances + sign-flip tags applied)")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
