#!/usr/bin/env python3
"""Evaluate generated samples against simple stats and vocab validity."""
import argparse
import csv
import gzip
import json
import math
from pathlib import Path
from typing import Dict, List, Optional


def load_json(path: str) -> Dict:
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def open_csv(path: str):
    # Transparently read gzip-compressed CSVs alongside plain ones.
    if path.endswith(".gz"):
        return gzip.open(path, "rt", newline="")
    return open(path, "r", newline="")


def parse_args():
    parser = argparse.ArgumentParser(description="Evaluate generated CSV samples.")
    base_dir = Path(__file__).resolve().parent
    parser.add_argument("--generated", default=str(base_dir / "results" / "generated.csv"))
    parser.add_argument("--split", default=str(base_dir / "feature_split.json"))
    parser.add_argument("--stats", default=str(base_dir / "results" / "cont_stats.json"))
    parser.add_argument("--vocab", default=str(base_dir / "results" / "disc_vocab.json"))
    parser.add_argument("--out", default=str(base_dir / "results" / "eval.json"))
    parser.add_argument("--reference", default="",
                        help="Optional reference CSV (train) for richer metrics")
    parser.add_argument("--max-rows", type=int, default=20000,
                        help="Max rows to load for reference metrics")
    return parser.parse_args()


def init_stats(cols):
    return {c: {"count": 0, "mean": 0.0, "m2": 0.0} for c in cols}


def update_stats(stats, col, value):
    # Welford's online algorithm: single-pass, numerically stable mean/variance.
    st = stats[col]
    st["count"] += 1
    delta = value - st["mean"]
    st["mean"] += delta / st["count"]
    delta2 = value - st["mean"]
    st["m2"] += delta * delta2


def finalize_stats(stats):
    out = {}
    for c, st in stats.items():
        # Sample variance (n - 1 denominator); zero when fewer than two values.
        var = st["m2"] / (st["count"] - 1) if st["count"] > 1 else 0.0
        out[c] = {"mean": st["mean"], "std": var ** 0.5}
    return out


def js_divergence(p, q, eps: float = 1e-12) -> float:
    # Jensen-Shannon divergence in bits (log base 2); bounded in [0, 1].
    # Zero probabilities are clamped to eps to keep the logs finite.
    p = [max(x, eps) for x in p]
    q = [max(x, eps) for x in q]
    m = [(pi + qi) / 2.0 for pi, qi in zip(p, q)]

    def kl(a, b):
        return sum(ai * math.log(ai / bi, 2) for ai, bi in zip(a, b))

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def ks_statistic(x: List[float], y: List[float]) -> float:
    # Two-sample Kolmogorov-Smirnov statistic: the maximum gap between the
    # two empirical CDFs, found with a single merge-style sweep.
    if not x or not y:
        return 0.0
    x_sorted = sorted(x)
    y_sorted = sorted(y)
    n = len(x_sorted)
    m = len(y_sorted)
    i = j = 0
    cdf_x = cdf_y = 0.0
    d = 0.0
    while i < n and j < m:
        if x_sorted[i] <= y_sorted[j]:
            i += 1
            cdf_x = i / n
        else:
            j += 1
            cdf_y = j / m
        d = max(d, abs(cdf_x - cdf_y))
    return d


def lag1_corr(values: List[float]) -> float:
    # Pearson correlation between the series and itself shifted by one step.
    if len(values) < 3:
        return 0.0
    x = values[:-1]
    y = values[1:]
    mean_x = sum(x) / len(x)
    mean_y = sum(y) / len(y)
    num = sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y))
    den_x = sum((xi - mean_x) ** 2 for xi in x)
    den_y = sum((yi - mean_y) ** 2 for yi in y)
    if den_x <= 0 or den_y <= 0:
        return 0.0
    return num / math.sqrt(den_x * den_y)


def resolve_reference_path(path: str) -> Optional[str]:
    # The reference argument may be a glob pattern; expand it and take the
    # first match so callers can pass e.g. "data/train*.csv".
    if not path:
        return None
    if any(ch in path for ch in ["*", "?", "["]):
        base = Path(path).parent
        pat = Path(path).name
        matches = sorted(base.glob(pat))
        return str(matches[0]) if matches else None
    return str(path)
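# Quick sanity checks for the metric helpers (illustrative values only, not
# taken from the pipeline): identical distributions score 0 on both metrics,
# while fully disjoint samples score the maximum of 1.
#
#   js_divergence([0.5, 0.5], [0.5, 0.5])   # -> 0.0
#   ks_statistic([0.0, 1.0], [10.0, 11.0])  # -> 1.0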
def main():
    args = parse_args()
    base_dir = Path(__file__).resolve().parent

    def absolutize(p: str) -> str:
        # Resolve relative paths against the script directory.
        return p if Path(p).is_absolute() else str((base_dir / p).resolve())

    args.generated = absolutize(args.generated)
    args.split = absolutize(args.split)
    args.stats = absolutize(args.stats)
    args.vocab = absolutize(args.vocab)
    args.out = absolutize(args.out)
    if args.reference and not Path(args.reference).is_absolute():
        if any(ch in args.reference for ch in ["*", "?", "["]):
            # Leave glob patterns un-resolved so resolve_reference_path
            # can expand them.
            args.reference = str(base_dir / args.reference)
        else:
            args.reference = str((base_dir / args.reference).resolve())
    ref_path = resolve_reference_path(args.reference)

    split = load_json(args.split)
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]

    stats_json = load_json(args.stats)
    stats_ref = stats_json.get("raw_mean", stats_json.get("mean"))
    std_ref = stats_json.get("raw_std", stats_json.get("std"))

    vocab = load_json(args.vocab)["vocab"]
    vocab_sets = {c: set(vocab[c].keys()) for c in disc_cols}

    # First pass over the generated CSV: streaming moments for continuous
    # columns, invalid-token counts for discrete columns.
    cont_stats = init_stats(cont_cols)
    disc_invalid = {c: 0 for c in disc_cols}
    rows = 0
    with open_csv(args.generated) as f:
        reader = csv.DictReader(f)
        for row in reader:
            rows += 1
            row.pop(time_col, None)
            for c in cont_cols:
                try:
                    v = float(row[c])
                except (KeyError, TypeError, ValueError):
                    v = 0.0
                update_stats(cont_stats, c, v)
            for c in disc_cols:
                if row[c] not in vocab_sets[c]:
                    disc_invalid[c] += 1

    cont_summary = finalize_stats(cont_stats)
    cont_err = {}
    for c in cont_cols:
        ref_mean = float(stats_ref[c])
        ref_std = float(std_ref[c]) if float(std_ref[c]) != 0 else 1.0
        gen_mean = cont_summary[c]["mean"]
        gen_std = cont_summary[c]["std"]
        cont_err[c] = {
            "mean_abs_err": abs(gen_mean - ref_mean),
            "std_abs_err": abs(gen_std - ref_std),
        }

    report = {
        "rows": rows,
        "continuous_summary": cont_summary,
        "continuous_error": cont_err,
        "discrete_invalid_counts": disc_invalid,
    }

    # Optional richer metrics using reference data
    if ref_path:
        ref_cont = {c: [] for c in cont_cols}
        ref_disc = {c: {} for c in disc_cols}
        gen_cont = {c: [] for c in cont_cols}
        gen_disc = {c: {} for c in disc_cols}
        with open_csv(args.generated) as f:
            reader = csv.DictReader(f)
            for row in reader:
                row.pop(time_col, None)
                for c in cont_cols:
                    try:
                        gen_cont[c].append(float(row[c]))
                    except (KeyError, TypeError, ValueError):
                        gen_cont[c].append(0.0)
                for c in disc_cols:
                    tok = row[c]
                    gen_disc[c][tok] = gen_disc[c].get(tok, 0) + 1
        with open_csv(ref_path) as f:
            reader = csv.DictReader(f)
            for i, row in enumerate(reader):
                row.pop(time_col, None)
                for c in cont_cols:
                    try:
                        ref_cont[c].append(float(row[c]))
                    except (KeyError, TypeError, ValueError):
                        ref_cont[c].append(0.0)
                for c in disc_cols:
                    tok = row[c]
                    ref_disc[c][tok] = ref_disc[c].get(tok, 0) + 1
                if args.max_rows and i + 1 >= args.max_rows:
                    break

        # Continuous metrics: KS + quantile diffs + lag-1 autocorrelation
        def qval(arr, q):
            # Quantile by flooring the fractional rank on a pre-sorted list.
            if not arr:
                return 0.0
            return arr[int(q * (len(arr) - 1))]

        quantiles = {"q05_diff": 0.05, "q25_diff": 0.25, "q50_diff": 0.5,
                     "q75_diff": 0.75, "q95_diff": 0.95}
        cont_ks = {}
        cont_quant = {}
        cont_lag1 = {}
        for c in cont_cols:
            cont_ks[c] = ks_statistic(gen_cont[c], ref_cont[c])
            ref_sorted = sorted(ref_cont[c])
            gen_sorted = sorted(gen_cont[c])
            cont_quant[c] = {
                key: abs(qval(gen_sorted, q) - qval(ref_sorted, q))
                for key, q in quantiles.items()
            }
            cont_lag1[c] = abs(lag1_corr(gen_cont[c]) - lag1_corr(ref_cont[c]))

        # Discrete metrics: JSD over vocab
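        # Probabilities are taken over the full training vocabulary, so
        # tokens the generator never emits still contribute to the JSD
        # (zero counts are clamped by eps inside js_divergence), and the
        # `or 1` below guards against division by zero on empty columns.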
        disc_jsd = {}
        for c in disc_cols:
            vocab_vals = list(vocab_sets[c])
            gen_total = sum(gen_disc[c].values()) or 1
            ref_total = sum(ref_disc[c].values()) or 1
            p = [gen_disc[c].get(v, 0) / gen_total for v in vocab_vals]
            q = [ref_disc[c].get(v, 0) / ref_total for v in vocab_vals]
            disc_jsd[c] = js_divergence(p, q)

        report["continuous_ks"] = cont_ks
        report["continuous_quantile_diff"] = cont_quant
        report["continuous_lag1_diff"] = cont_lag1
        report["discrete_jsd"] = disc_jsd

    with open(args.out, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)
    print("eval_report", args.out)


if __name__ == "__main__":
    main()
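# Example invocation (the script name and reference path here are
# illustrative; every flag defaults to a file under ./results next to
# this script):
#   python evaluate_samples.py --reference "data/train*.csv" --max-rows 20000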