#!/usr/bin/env python3
|
|
"""Evaluate generated samples against simple stats and vocab validity."""
|
|
|
|
import argparse
|
|
import csv
|
|
import gzip
|
|
import json
|
|
from pathlib import Path
|
|
from typing import Dict, Tuple
|
|
|
|
|
|
def load_json(path: str) -> Dict:
    """Load and return the JSON document stored at *path*.

    Opens the file as UTF-8 (the standard encoding for JSON) rather than
    ASCII, so documents containing non-ASCII characters do not raise
    UnicodeDecodeError.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
|
|
|
|
|
|
def open_csv(path: str):
    """Open *path* for text-mode reading, transparently handling ``.gz`` files.

    Returns a file object suitable for ``csv.reader``/``csv.DictReader``
    (``newline=""`` per the csv module's requirements).
    """
    opener = gzip.open if path.endswith(".gz") else open
    return opener(path, "rt", newline="")
|
|
|
|
|
|
def parse_args():
    """Parse command-line options for the evaluation script.

    All paths default to locations relative to this file's directory.
    """
    base_dir = Path(__file__).resolve().parent
    results_dir = base_dir / "results"
    parser = argparse.ArgumentParser(description="Evaluate generated CSV samples.")
    defaults = (
        ("--generated", results_dir / "generated.csv"),
        ("--split", base_dir / "feature_split.json"),
        ("--stats", results_dir / "cont_stats.json"),
        ("--vocab", results_dir / "disc_vocab.json"),
        ("--out", results_dir / "eval.json"),
    )
    for flag, path in defaults:
        parser.add_argument(flag, default=str(path))
    return parser.parse_args()
|
|
|
|
|
|
def init_stats(cols):
    """Create a fresh Welford accumulator (count/mean/m2) for each column."""
    stats = {}
    for name in cols:
        stats[name] = {"count": 0, "mean": 0.0, "m2": 0.0}
    return stats
|
|
|
|
|
|
def update_stats(stats, col, value):
    """Fold one observation into the running accumulator for *col*.

    Implements a single step of Welford's online algorithm: the mean is
    updated first, and ``m2`` (sum of squared deviations) uses the deltas
    from both the pre- and post-update means for numerical stability.
    """
    acc = stats[col]
    acc["count"] += 1
    delta_before = value - acc["mean"]
    acc["mean"] += delta_before / acc["count"]
    delta_after = value - acc["mean"]
    acc["m2"] += delta_before * delta_after
|
|
|
|
|
|
def finalize_stats(stats):
    """Convert Welford accumulators into ``{col: {"mean", "std"}}``.

    Uses the sample (n-1) variance; columns with fewer than two
    observations report a standard deviation of 0.0.
    """
    summary = {}
    for name, acc in stats.items():
        n = acc["count"]
        variance = acc["m2"] / (n - 1) if n > 1 else 0.0
        summary[name] = {"mean": acc["mean"], "std": variance ** 0.5}
    return summary
|
|
|
|
|
|
def main():
    """Evaluate a generated CSV against reference statistics and vocabularies.

    Steps:
      1. Read the feature split to identify continuous vs. discrete columns
         (the time column and attack-label columns are excluded).
      2. Stream the generated CSV, accumulating Welford statistics for
         continuous columns and out-of-vocabulary counts for discrete ones.
      3. Write a JSON report with per-column mean/std errors against the
         reference stats and invalid-value counts.
    """
    args = parse_args()

    split = load_json(args.split)
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    # Attack/label columns are not evaluated for vocabulary validity.
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]

    # Load the reference stats file once (previously parsed twice).
    stats_file = load_json(args.stats)
    stats_ref = stats_file["mean"]
    std_ref = stats_file["std"]
    vocab = load_json(args.vocab)["vocab"]
    vocab_sets = {c: set(vocab[c].keys()) for c in disc_cols}

    cont_stats = init_stats(cont_cols)
    disc_invalid = {c: 0 for c in disc_cols}
    rows = 0

    with open_csv(args.generated) as f:
        reader = csv.DictReader(f)
        for row in reader:
            rows += 1
            # pop with a default is already a no-op when the key is absent.
            row.pop(time_col, None)
            for c in cont_cols:
                try:
                    v = float(row[c])
                except (KeyError, TypeError, ValueError):
                    # Best-effort: treat missing/unparseable values as 0.0
                    # rather than aborting the whole evaluation.
                    v = 0.0
                update_stats(cont_stats, c, v)
            for c in disc_cols:
                if row[c] not in vocab_sets[c]:
                    disc_invalid[c] += 1

    cont_summary = finalize_stats(cont_stats)
    cont_err = {}
    for c in cont_cols:
        ref_mean = float(stats_ref[c])
        # A zero reference std is replaced by 1.0 (avoids degenerate targets).
        ref_std = float(std_ref[c]) if float(std_ref[c]) != 0 else 1.0
        cont_err[c] = {
            "mean_abs_err": abs(cont_summary[c]["mean"] - ref_mean),
            "std_abs_err": abs(cont_summary[c]["std"] - ref_std),
        }

    report = {
        "rows": rows,
        "continuous_summary": cont_summary,
        "continuous_error": cont_err,
        "discrete_invalid_counts": disc_invalid,
    }

    # json.dump defaults to ensure_ascii=True, so the output is ASCII-safe
    # regardless; utf-8 here avoids encode errors if that default changes.
    with open(args.out, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)

    print("eval_report", args.out)
|
|
|
|
|
|
# Run the evaluation only when executed as a script (not on import).
if __name__ == "__main__":
    main()
|