#!/usr/bin/env python3
"""Evaluate generated samples against simple stats and vocab validity."""

import argparse
import csv
import gzip
import json
import math
from pathlib import Path
from typing import Dict, List


def load_json(path: str) -> Dict:
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def open_csv(path: str):
    if path.endswith(".gz"):
        return gzip.open(path, "rt", newline="")
    return open(path, "r", newline="")


def parse_args():
    parser = argparse.ArgumentParser(description="Evaluate generated CSV samples.")
    base_dir = Path(__file__).resolve().parent
    parser.add_argument("--generated", default=str(base_dir / "results" / "generated.csv"))
    parser.add_argument("--split", default=str(base_dir / "feature_split.json"))
    parser.add_argument("--stats", default=str(base_dir / "results" / "cont_stats.json"))
    parser.add_argument("--vocab", default=str(base_dir / "results" / "disc_vocab.json"))
    parser.add_argument("--out", default=str(base_dir / "results" / "eval.json"))
    parser.add_argument("--reference", default="",
                        help="Optional reference CSV (train) for richer metrics")
    parser.add_argument("--max-rows", type=int, default=20000,
                        help="Max rows to load for reference metrics")
    return parser.parse_args()


def init_stats(cols):
    return {c: {"count": 0, "mean": 0.0, "m2": 0.0} for c in cols}


def update_stats(stats, col, value):
    # Welford's online algorithm: single-pass, numerically stable mean/variance.
    st = stats[col]
    st["count"] += 1
    delta = value - st["mean"]
    st["mean"] += delta / st["count"]
    delta2 = value - st["mean"]
    st["m2"] += delta * delta2


def finalize_stats(stats):
    out = {}
    for c, st in stats.items():
        if st["count"] > 1:
            var = st["m2"] / (st["count"] - 1)  # sample variance
        else:
            var = 0.0
        out[c] = {"mean": st["mean"], "std": var ** 0.5}
    return out


def js_divergence(p, q, eps: float = 1e-12) -> float:
    # Jensen-Shannon divergence in bits; eps keeps log() away from zero.
    p = [max(x, eps) for x in p]
    q = [max(x, eps) for x in q]
    m = [(pi + qi) / 2.0 for pi, qi in zip(p, q)]

    def kl(a, b):
        return sum(ai * math.log(ai / bi, 2) for ai, bi in zip(a, b))

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def ks_statistic(x: List[float], y: List[float]) -> float:
    # Two-sample Kolmogorov-Smirnov statistic: max gap between empirical CDFs.
    if not x or not y:
        return 0.0
    x_sorted = sorted(x)
    y_sorted = sorted(y)
    n = len(x_sorted)
    m = len(y_sorted)
    i = j = 0
    d = 0.0
    # Iterate over merged unique values to handle ties correctly.
    merged = sorted(set(x_sorted) | set(y_sorted))
    for v in merged:
        while i < n and x_sorted[i] <= v:
            i += 1
        while j < m and y_sorted[j] <= v:
            j += 1
        cdf_x = i / n
        cdf_y = j / m
        d = max(d, abs(cdf_x - cdf_y))
    return d


def lag1_corr(values: List[float]) -> float:
    # Pearson correlation between the series and itself shifted by one step.
    if len(values) < 3:
        return 0.0
    x = values[:-1]
    y = values[1:]
    mean_x = sum(x) / len(x)
    mean_y = sum(y) / len(y)
    num = sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y))
    den_x = sum((xi - mean_x) ** 2 for xi in x)
    den_y = sum((yi - mean_y) ** 2 for yi in y)
    if den_x <= 0 or den_y <= 0:
        return 0.0
    return num / math.sqrt(den_x * den_y)


def resolve_reference_paths(path: str) -> List[str]:
    if not path:
        return []
    if any(ch in path for ch in ["*", "?", "["]):
        base = Path(path).parent.resolve()
        pat = Path(path).name
        matches = sorted(base.glob(pat))
        return [str(p) for p in matches]
    return [str(path)]


def main():
    args = parse_args()
    base_dir = Path(__file__).resolve().parent

    def resolve_file(p: str) -> str:
        path = Path(p)
        if path.is_absolute():
            return str(path)
        if path.exists():
            return str(path.resolve())
        candidate = base_dir / path
        if candidate.exists():
            return str(candidate.resolve())
        # Fall back to the script-relative path even if it does not exist yet.
        return str((base_dir / path).resolve())

    args.generated = resolve_file(args.generated)
    args.split = resolve_file(args.split)
    args.stats = resolve_file(args.stats)
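    # Resolve the remaining artifact paths the same way, so relative defaults
    # behave identically regardless of the current working directory.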
    args.vocab = resolve_file(args.vocab)
    args.out = resolve_file(args.out)
    if args.reference and not Path(args.reference).is_absolute():
        if any(ch in args.reference for ch in ["*", "?", "["]):
            # Keep glob patterns unresolved so resolve_reference_paths can expand them.
            args.reference = str(base_dir / args.reference)
        else:
            args.reference = str((base_dir / args.reference).resolve())
    ref_paths = resolve_reference_paths(args.reference)

    split = load_json(args.split)
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]

    stats_json = load_json(args.stats)
    stats_ref = stats_json.get("raw_mean", stats_json.get("mean"))
    std_ref = stats_json.get("raw_std", stats_json.get("std"))

    vocab = load_json(args.vocab)["vocab"]
    vocab_sets = {c: set(vocab[c].keys()) for c in disc_cols}

    # First pass: streaming mean/std per continuous column and vocab-validity
    # counts per discrete column.
    cont_stats = init_stats(cont_cols)
    disc_invalid = {c: 0 for c in disc_cols}
    rows = 0
    with open_csv(args.generated) as f:
        reader = csv.DictReader(f)
        for row in reader:
            rows += 1
            row.pop(time_col, None)
            for c in cont_cols:
                try:
                    v = float(row[c])
                except (KeyError, TypeError, ValueError):
                    v = 0.0
                update_stats(cont_stats, c, v)
            for c in disc_cols:
                if row.get(c) not in vocab_sets[c]:
                    disc_invalid[c] += 1

    cont_summary = finalize_stats(cont_stats)
    cont_err = {}
    for c in cont_cols:
        ref_mean = float(stats_ref[c])
        ref_std = float(std_ref[c]) if float(std_ref[c]) != 0 else 1.0
        gen_mean = cont_summary[c]["mean"]
        gen_std = cont_summary[c]["std"]
        cont_err[c] = {
            "mean_abs_err": abs(gen_mean - ref_mean),
            "std_abs_err": abs(gen_std - ref_std),
        }

    report = {
        "rows": rows,
        "continuous_summary": cont_summary,
        "continuous_error": cont_err,
        "discrete_invalid_counts": disc_invalid,
    }

    # Optional richer metrics against reference data.
    if ref_paths:
        ref_cont = {c: [] for c in cont_cols}
        ref_disc = {c: {} for c in disc_cols}
        gen_cont = {c: [] for c in cont_cols}
        gen_disc = {c: {} for c in disc_cols}
        with open_csv(args.generated) as f:
            reader = csv.DictReader(f)
            for row in reader:
                row.pop(time_col, None)
                for c in cont_cols:
                    try:
                        gen_cont[c].append(float(row[c]))
                    except (KeyError, TypeError, ValueError):
                        gen_cont[c].append(0.0)
                for c in disc_cols:
                    tok = row[c]
                    gen_disc[c][tok] = gen_disc[c].get(tok, 0) + 1
        loaded = 0
        for ref_path in ref_paths:
            with open_csv(ref_path) as f:
                reader = csv.DictReader(f)
                for row in reader:
                    row.pop(time_col, None)
                    for c in cont_cols:
                        try:
                            ref_cont[c].append(float(row[c]))
                        except (KeyError, TypeError, ValueError):
                            ref_cont[c].append(0.0)
                    for c in disc_cols:
                        tok = row[c]
                        ref_disc[c][tok] = ref_disc[c].get(tok, 0) + 1
                    loaded += 1
                    if args.max_rows and loaded >= args.max_rows:
                        break
            if args.max_rows and loaded >= args.max_rows:
                break

        # Continuous metrics: KS statistic, quantile differences, lag-1 autocorrelation.
        cont_ks = {}
        cont_quant = {}
        cont_lag1 = {}

        def qval(arr, q):
            # Nearest-rank quantile on a pre-sorted list.
            if not arr:
                return 0.0
            idx = int(q * (len(arr) - 1))
            return arr[idx]

        for c in cont_cols:
            cont_ks[c] = ks_statistic(gen_cont[c], ref_cont[c])
            ref_sorted = sorted(ref_cont[c])
            gen_sorted = sorted(gen_cont[c])
            cont_quant[c] = {
                "q05_diff": abs(qval(gen_sorted, 0.05) - qval(ref_sorted, 0.05)),
                "q25_diff": abs(qval(gen_sorted, 0.25) - qval(ref_sorted, 0.25)),
                "q50_diff": abs(qval(gen_sorted, 0.5) - qval(ref_sorted, 0.5)),
                "q75_diff": abs(qval(gen_sorted, 0.75) - qval(ref_sorted, 0.75)),
                "q95_diff": abs(qval(gen_sorted, 0.95) - qval(ref_sorted, 0.95)),
            }
            cont_lag1[c] = abs(lag1_corr(gen_cont[c]) - lag1_corr(ref_cont[c]))

        # Discrete metrics: JSD over the training vocabulary.
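        # Tokens absent from the training vocabulary are dropped from both
        # distributions here; they are already surfaced per column via
        # discrete_invalid_counts in the first pass.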
        disc_jsd = {}
        for c in disc_cols:
            vocab_vals = sorted(vocab_sets[c])  # fixed order so p and q align
            gen_total = sum(gen_disc[c].values()) or 1
            ref_total = sum(ref_disc[c].values()) or 1
            p = [gen_disc[c].get(v, 0) / gen_total for v in vocab_vals]
            q = [ref_disc[c].get(v, 0) / ref_total for v in vocab_vals]
            disc_jsd[c] = js_divergence(p, q)

        report["continuous_ks"] = cont_ks
        report["continuous_quantile_diff"] = cont_quant
        report["continuous_lag1_diff"] = cont_lag1
        report["discrete_jsd"] = disc_jsd

    with open(args.out, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)
    print("eval_report", args.out)


if __name__ == "__main__":
    main()
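
# Illustrative invocation (the script name and reference glob are assumptions;
# adjust to your layout -- defaults match the paths in parse_args above):
#   python evaluate.py --generated results/generated.csv \
#       --reference "data/train_*.csv.gz" --max-rows 50000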