Files
mask-ddpm/example/evaluate_generated.py

117 lines
3.4 KiB
Python

#!/usr/bin/env python3
"""Evaluate generated samples against simple stats and vocab validity."""
import argparse
import csv
import gzip
import json
from pathlib import Path
from typing import Dict, Tuple
def load_json(path: str) -> Dict:
    """Load and return a JSON document from *path*.

    Opens the file as UTF-8, the encoding the JSON spec (RFC 8259) mandates.
    The previous ``encoding="ascii"`` raised UnicodeDecodeError as soon as any
    value in the file contained a non-ASCII character; UTF-8 is a strict
    superset of ASCII, so this change is backward compatible.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
def open_csv(path: str):
    """Open *path* for text-mode reading, transparently handling gzip.

    A ``.gz`` suffix selects :func:`gzip.open` in text mode; anything else is
    opened as a plain text file. ``newline=""`` is passed in both cases, as the
    ``csv`` module requires.
    """
    opener = gzip.open if path.endswith(".gz") else open
    mode = "rt" if opener is gzip.open else "r"
    return opener(path, mode, newline="")
def parse_args(argv=None):
    """Parse command-line options for the evaluation script.

    Args:
        argv: Optional explicit argument list. ``None`` (the default, and the
            previous behavior) makes argparse read ``sys.argv[1:]``; passing a
            list makes the function testable and reusable programmatically.

    Returns:
        argparse.Namespace with ``generated``, ``split``, ``stats``, ``vocab``
        and ``out`` paths, all defaulting to locations relative to this file.
    """
    parser = argparse.ArgumentParser(description="Evaluate generated CSV samples.")
    base_dir = Path(__file__).resolve().parent
    parser.add_argument("--generated", default=str(base_dir / "results" / "generated.csv"))
    parser.add_argument("--split", default=str(base_dir / "feature_split.json"))
    parser.add_argument("--stats", default=str(base_dir / "results" / "cont_stats.json"))
    parser.add_argument("--vocab", default=str(base_dir / "results" / "disc_vocab.json"))
    parser.add_argument("--out", default=str(base_dir / "results" / "eval.json"))
    return parser.parse_args(argv)
def init_stats(cols):
    """Return a fresh Welford accumulator (count/mean/m2) for each column."""
    fields = (("count", 0), ("mean", 0.0), ("m2", 0.0))
    return {col: dict(fields) for col in cols}
def update_stats(stats, col, value):
    """Fold *value* into the running statistics for *col*.

    Single step of Welford's online algorithm: updates count, mean, and the
    sum of squared deviations (``m2``) in place.
    """
    entry = stats[col]
    n = entry["count"] + 1
    entry["count"] = n
    before = value - entry["mean"]
    entry["mean"] += before / n
    after = value - entry["mean"]
    entry["m2"] += before * after
def finalize_stats(stats):
    """Convert Welford accumulators into {"mean", "std"} summaries.

    Standard deviation is the sample std (``m2 / (count - 1)``); with fewer
    than two observations it is reported as 0.0.
    """
    summary = {}
    for name, acc in stats.items():
        n = acc["count"]
        variance = acc["m2"] / (n - 1) if n > 1 else 0.0
        summary[name] = {"mean": acc["mean"], "std": variance ** 0.5}
    return summary
def main():
    """Evaluate a generated CSV against reference stats and vocab.

    Reads the feature split, reference continuous stats, and discrete vocab,
    streams the generated CSV once, and writes a JSON report containing row
    count, per-column continuous mean/std, absolute error versus the reference
    stats, and per-column counts of out-of-vocabulary discrete values.
    """
    args = parse_args()
    split = load_json(args.split)
    time_col = split.get("time_column", "time")
    # The time column is excluded everywhere; attack* label columns are also
    # excluded from the discrete validity check.
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
    # Read the reference-stats file once (it was previously parsed twice).
    ref_stats = load_json(args.stats)
    stats_ref = ref_stats["mean"]
    std_ref = ref_stats["std"]
    vocab = load_json(args.vocab)["vocab"]
    # Pre-build membership sets so the per-row check is O(1) per column.
    vocab_sets = {c: set(vocab[c].keys()) for c in disc_cols}
    cont_stats = init_stats(cont_cols)
    disc_invalid = {c: 0 for c in disc_cols}
    rows = 0
    with open_csv(args.generated) as f:
        reader = csv.DictReader(f)
        for row in reader:
            rows += 1
            # pop with a default already tolerates a missing key; the former
            # `if time_col in row` guard was redundant.
            row.pop(time_col, None)
            for c in cont_cols:
                try:
                    v = float(row[c])
                except (KeyError, TypeError, ValueError):
                    # Missing/unparseable values are scored as 0.0, matching
                    # the original best-effort behavior.
                    v = 0.0
                update_stats(cont_stats, c, v)
            for c in disc_cols:
                if row[c] not in vocab_sets[c]:
                    disc_invalid[c] += 1
    cont_summary = finalize_stats(cont_stats)
    cont_err = {}
    for c in cont_cols:
        ref_mean = float(stats_ref[c])
        ref_std = float(std_ref[c])
        if ref_std == 0:
            # Avoid a degenerate 0-std reference; compare against 1.0 instead.
            ref_std = 1.0
        cont_err[c] = {
            "mean_abs_err": abs(cont_summary[c]["mean"] - ref_mean),
            "std_abs_err": abs(cont_summary[c]["std"] - ref_std),
        }
    report = {
        "rows": rows,
        "continuous_summary": cont_summary,
        "continuous_error": cont_err,
        "discrete_invalid_counts": disc_invalid,
    }
    # json.dump defaults to ensure_ascii=True, so an ASCII output file is safe
    # even if column names contain non-ASCII characters.
    with open(args.out, "w", encoding="ascii") as f:
        json.dump(report, f, indent=2)
    print("eval_report", args.out)
# Script entry point: run the evaluation only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()