mask-ddpm/example/evaluate_generated.py

#!/usr/bin/env python3
"""Evaluate generated samples against simple stats and vocab validity."""
import argparse
import csv
import gzip
import json
import math
from pathlib import Path
from typing import Dict, List, Optional


def load_json(path: str) -> Dict:
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def open_csv(path: str):
    # Open plain or gzip-compressed CSVs transparently.
    if path.endswith(".gz"):
        return gzip.open(path, "rt", newline="")
    return open(path, "r", newline="")


def parse_args():
    parser = argparse.ArgumentParser(description="Evaluate generated CSV samples.")
    base_dir = Path(__file__).resolve().parent
    parser.add_argument("--generated", default=str(base_dir / "results" / "generated.csv"))
    parser.add_argument("--split", default=str(base_dir / "feature_split.json"))
    parser.add_argument("--stats", default=str(base_dir / "results" / "cont_stats.json"))
    parser.add_argument("--vocab", default=str(base_dir / "results" / "disc_vocab.json"))
    parser.add_argument("--out", default=str(base_dir / "results" / "eval.json"))
    parser.add_argument("--reference", default="", help="Optional reference CSV (train) for richer metrics")
    parser.add_argument("--max-rows", type=int, default=20000, help="Max rows to load for reference metrics")
    return parser.parse_args()


def init_stats(cols):
    return {c: {"count": 0, "mean": 0.0, "m2": 0.0} for c in cols}


def update_stats(stats, col, value):
    # Welford's online update: running mean plus M2, the running sum of
    # squared deviations from the current mean.
    st = stats[col]
    st["count"] += 1
    delta = value - st["mean"]
    st["mean"] += delta / st["count"]
    delta2 = value - st["mean"]
    st["m2"] += delta * delta2


def finalize_stats(stats):
    out = {}
    for c, st in stats.items():
        if st["count"] > 1:
            var = st["m2"] / (st["count"] - 1)  # sample variance
        else:
            var = 0.0
        out[c] = {"mean": st["mean"], "std": var ** 0.5}
    return out
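# Quick sanity check for the Welford helpers above (hypothetical values, not
# part of the original script):
#
#   stats = init_stats(["x"])
#   for v in [1.0, 2.0, 3.0]:
#       update_stats(stats, "x", v)
#   finalize_stats(stats)  # -> {"x": {"mean": 2.0, "std": 1.0}}
#
# The sample variance of [1, 2, 3] is 1.0, so the std comes out as 1.0.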


def js_divergence(p, q, eps: float = 1e-12) -> float:
    # Jensen-Shannon divergence in bits; eps keeps log() away from zero.
    p = [max(x, eps) for x in p]
    q = [max(x, eps) for x in q]
    m = [(pi + qi) / 2.0 for pi, qi in zip(p, q)]

    def kl(a, b):
        return sum(ai * math.log(ai / bi, 2) for ai, bi in zip(a, b))

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)
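# Example behavior (hypothetical inputs): identical distributions score ~0,
# and disjoint ones hit 1 bit, the maximum for base-2 JSD:
#
#   js_divergence([0.5, 0.5], [0.5, 0.5])  # -> ~0.0
#   js_divergence([1.0, 0.0], [0.0, 1.0])  # -> ~1.0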


def ks_statistic(x: List[float], y: List[float]) -> float:
    # Two-sample Kolmogorov-Smirnov statistic via a merge over the sorted
    # samples. Ties advance both pointers so equal values do not inflate d.
    if not x or not y:
        return 0.0
    x_sorted = sorted(x)
    y_sorted = sorted(y)
    n = len(x_sorted)
    m = len(y_sorted)
    i = j = 0
    cdf_x = cdf_y = 0.0
    d = 0.0
    while i < n and j < m:
        xv, yv = x_sorted[i], y_sorted[j]
        if xv <= yv:
            i += 1
            cdf_x = i / n
        if yv <= xv:
            j += 1
            cdf_y = j / m
        d = max(d, abs(cdf_x - cdf_y))
    return d
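# Example behavior (hypothetical inputs): identical samples give 0, samples
# with disjoint supports give the maximum of 1:
#
#   ks_statistic([1.0, 2.0, 3.0], [1.0, 2.0, 3.0])     # -> 0.0
#   ks_statistic([1.0, 2.0, 3.0], [11.0, 12.0, 13.0])  # -> 1.0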


def lag1_corr(values: List[float]) -> float:
    # Pearson correlation between the series and a copy shifted by one step.
    if len(values) < 3:
        return 0.0
    x = values[:-1]
    y = values[1:]
    mean_x = sum(x) / len(x)
    mean_y = sum(y) / len(y)
    num = sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y))
    den_x = sum((xi - mean_x) ** 2 for xi in x)
    den_y = sum((yi - mean_y) ** 2 for yi in y)
    if den_x <= 0 or den_y <= 0:
        return 0.0
    return num / math.sqrt(den_x * den_y)
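# Example behavior (hypothetical inputs): a monotone ramp is perfectly lag-1
# correlated, and a constant series returns 0 via the zero-variance guard:
#
#   lag1_corr([1.0, 2.0, 3.0, 4.0])  # -> 1.0
#   lag1_corr([5.0, 5.0, 5.0])       # -> 0.0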


def resolve_reference_path(path: str) -> Optional[str]:
    if not path:
        return None
    if any(ch in path for ch in ["*", "?", "["]):
        # Treat the path as a glob pattern and take the first sorted match.
        base = Path(path).parent
        pat = Path(path).name
        matches = sorted(base.glob(pat))
        return str(matches[0]) if matches else None
    return path
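# Example behavior (hypothetical file names): a glob such as
# "data/train_*.csv" resolves to its lexicographically first match; a literal
# path is returned unchanged:
#
#   resolve_reference_path("data/train_*.csv")  # -> "data/train_000.csv" (if present)
#   resolve_reference_path("data/train.csv")    # -> "data/train.csv"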


def main():
    args = parse_args()
    base_dir = Path(__file__).resolve().parent

    # Resolve any relative paths against the script directory.
    def absolutize(p: str) -> str:
        return p if not p or Path(p).is_absolute() else str((base_dir / p).resolve())

    for attr in ("generated", "split", "stats", "vocab", "out", "reference"):
        setattr(args, attr, absolutize(getattr(args, attr)))

    ref_path = resolve_reference_path(args.reference)
    split = load_json(args.split)
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
    stats_json = load_json(args.stats)
    stats_ref = stats_json.get("raw_mean", stats_json.get("mean"))
    std_ref = stats_json.get("raw_std", stats_json.get("std"))
    transforms = stats_json.get("transform", {})  # loaded for completeness; unused below
    vocab = load_json(args.vocab)["vocab"]
    vocab_sets = {c: set(vocab[c].keys()) for c in disc_cols}

    # First pass: streaming stats over the generated CSV plus vocab checks.
    cont_stats = init_stats(cont_cols)
    disc_invalid = {c: 0 for c in disc_cols}
    rows = 0
    with open_csv(args.generated) as f:
        reader = csv.DictReader(f)
        for row in reader:
            rows += 1
            row.pop(time_col, None)  # the time column is not evaluated
            for c in cont_cols:
                try:
                    v = float(row[c])
                except (KeyError, TypeError, ValueError):
                    v = 0.0
                update_stats(cont_stats, c, v)
            for c in disc_cols:
                if row[c] not in vocab_sets[c]:
                    disc_invalid[c] += 1

    cont_summary = finalize_stats(cont_stats)
    cont_err = {}
    for c in cont_cols:
        ref_mean = float(stats_ref[c])
        ref_std = float(std_ref[c]) or 1.0  # guard against zero reference std
        gen_mean = cont_summary[c]["mean"]
        gen_std = cont_summary[c]["std"]
        cont_err[c] = {
            "mean_abs_err": abs(gen_mean - ref_mean),
            "std_abs_err": abs(gen_std - ref_std),
        }
    report = {
        "rows": rows,
        "continuous_summary": cont_summary,
        "continuous_error": cont_err,
        "discrete_invalid_counts": disc_invalid,
    }

    # Optional richer metrics using reference data
    if ref_path:
        ref_cont = {c: [] for c in cont_cols}
        ref_disc = {c: {} for c in disc_cols}
        gen_cont = {c: [] for c in cont_cols}
        gen_disc = {c: {} for c in disc_cols}
        with open_csv(args.generated) as f:
            reader = csv.DictReader(f)
            for row in reader:
                row.pop(time_col, None)
                for c in cont_cols:
                    try:
                        gen_cont[c].append(float(row[c]))
                    except (KeyError, TypeError, ValueError):
                        gen_cont[c].append(0.0)
                for c in disc_cols:
                    tok = row[c]
                    gen_disc[c][tok] = gen_disc[c].get(tok, 0) + 1
        with open_csv(ref_path) as f:
            reader = csv.DictReader(f)
            for i, row in enumerate(reader):
                row.pop(time_col, None)
                for c in cont_cols:
                    try:
                        ref_cont[c].append(float(row[c]))
                    except (KeyError, TypeError, ValueError):
                        ref_cont[c].append(0.0)
                for c in disc_cols:
                    tok = row[c]
                    ref_disc[c][tok] = ref_disc[c].get(tok, 0) + 1
                if args.max_rows and i + 1 >= args.max_rows:
                    break

        # Continuous metrics: KS + quantiles + lag1 correlation
        def qval(arr, q):
            # Nearest-rank quantile on a pre-sorted list (no interpolation).
            if not arr:
                return 0.0
            idx = int(q * (len(arr) - 1))
            return arr[idx]

        cont_ks = {}
        cont_quant = {}
        cont_lag1 = {}
        for c in cont_cols:
            cont_ks[c] = ks_statistic(gen_cont[c], ref_cont[c])
            ref_sorted = sorted(ref_cont[c])
            gen_sorted = sorted(gen_cont[c])
            cont_quant[c] = {
                "q05_diff": abs(qval(gen_sorted, 0.05) - qval(ref_sorted, 0.05)),
                "q25_diff": abs(qval(gen_sorted, 0.25) - qval(ref_sorted, 0.25)),
                "q50_diff": abs(qval(gen_sorted, 0.5) - qval(ref_sorted, 0.5)),
                "q75_diff": abs(qval(gen_sorted, 0.75) - qval(ref_sorted, 0.75)),
                "q95_diff": abs(qval(gen_sorted, 0.95) - qval(ref_sorted, 0.95)),
            }
            cont_lag1[c] = abs(lag1_corr(gen_cont[c]) - lag1_corr(ref_cont[c]))

        # Discrete metrics: JSD over vocab
        disc_jsd = {}
        for c in disc_cols:
            vocab_vals = list(vocab_sets[c])
            gen_total = sum(gen_disc[c].values()) or 1
            ref_total = sum(ref_disc[c].values()) or 1
            p = [gen_disc[c].get(v, 0) / gen_total for v in vocab_vals]
            q = [ref_disc[c].get(v, 0) / ref_total for v in vocab_vals]
            disc_jsd[c] = js_divergence(p, q)
report["continuous_ks"] = cont_ks
report["continuous_quantile_diff"] = cont_quant
report["continuous_lag1_diff"] = cont_lag1
report["discrete_jsd"] = disc_jsd
with open(args.out, "w", encoding="utf-8") as f:
json.dump(report, f, indent=2)
print("eval_report", args.out)
if __name__ == "__main__":
main()
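# Example invocation (paths are illustrative; the defaults point at ./results):
#
#   python evaluate_generated.py \
#       --generated results/generated.csv \
#       --reference "data/train_*.csv" \
#       --max-rows 10000
#
# The report is written to results/eval.json and its path echoed on stdout.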