Clean artifacts and update example pipeline
This commit is contained in:
116
example/evaluate_generated.py
Normal file
116
example/evaluate_generated.py
Normal file
@@ -0,0 +1,116 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Evaluate generated samples against simple stats and vocab validity."""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import gzip
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple
|
||||
|
||||
|
||||
def load_json(path: str) -> Dict:
    """Load and return a JSON document from *path*.

    Bug fix: the file is now opened as UTF-8 (the de-facto JSON encoding)
    instead of ASCII, which raised UnicodeDecodeError on any non-ASCII
    content in the stats/vocab/split files.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
|
||||
|
||||
|
||||
def open_csv(path: str):
    """Open *path* for text-mode CSV reading, decompressing if gzipped.

    Plain files are opened directly; a ``.gz`` suffix selects transparent
    gzip decompression. Both handles use ``newline=""`` as the csv module
    requires.
    """
    if not path.endswith(".gz"):
        return open(path, "r", newline="")
    return gzip.open(path, "rt", newline="")
|
||||
|
||||
|
||||
def parse_args(argv=None):
    """Parse command-line options for the evaluation script.

    Parameters
    ----------
    argv : list[str] | None
        Explicit argument list, mainly for testing. ``None`` (the default,
        and the previous behavior) falls back to ``sys.argv[1:]``.

    All paths default to locations relative to this file's directory.
    """
    parser = argparse.ArgumentParser(description="Evaluate generated CSV samples.")
    base_dir = Path(__file__).resolve().parent
    parser.add_argument("--generated", default=str(base_dir / "results" / "generated.csv"))
    parser.add_argument("--split", default=str(base_dir / "feature_split.json"))
    parser.add_argument("--stats", default=str(base_dir / "results" / "cont_stats.json"))
    parser.add_argument("--vocab", default=str(base_dir / "results" / "disc_vocab.json"))
    parser.add_argument("--out", default=str(base_dir / "results" / "eval.json"))
    return parser.parse_args(argv)
|
||||
|
||||
|
||||
def init_stats(cols):
    """Build a fresh Welford accumulator (count / mean / m2) per column."""
    accumulators = {}
    for name in cols:
        accumulators[name] = {"count": 0, "mean": 0.0, "m2": 0.0}
    return accumulators
|
||||
|
||||
|
||||
def update_stats(stats, col, value):
    """Fold one observation into the running Welford accumulator for *col*.

    Updates count, mean, and the sum of squared deviations (m2) in place;
    variance is derived later by ``finalize_stats``.
    """
    acc = stats[col]
    acc["count"] = acc["count"] + 1
    deviation = value - acc["mean"]
    acc["mean"] = acc["mean"] + deviation / acc["count"]
    # Second deviation uses the *updated* mean, per Welford's algorithm.
    acc["m2"] = acc["m2"] + deviation * (value - acc["mean"])
|
||||
|
||||
|
||||
def finalize_stats(stats):
    """Convert Welford accumulators into ``{col: {"mean", "std"}}``.

    Uses the sample (n-1) variance; fewer than two observations yield a
    standard deviation of 0.0.
    """
    summary = {}
    for name, acc in stats.items():
        n = acc["count"]
        variance = acc["m2"] / (n - 1) if n > 1 else 0.0
        summary[name] = {"mean": acc["mean"], "std": variance ** 0.5}
    return summary
|
||||
|
||||
|
||||
def main():
    """Evaluate generated CSV samples against reference stats and vocab.

    Streams the generated CSV once, accumulating Welford statistics for
    the continuous columns and out-of-vocabulary counts for the discrete
    columns, then writes a JSON report (row count, per-column summary,
    absolute mean/std errors vs. the reference, invalid-value counts)
    to ``--out``.
    """
    args = parse_args()

    split = load_json(args.split)
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    # Columns prefixed "attack" are labels, not features; exclude them.
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]

    # Load the reference stats file once (previously read from disk twice).
    stats_file = load_json(args.stats)
    stats_ref = stats_file["mean"]
    std_ref = stats_file["std"]
    vocab = load_json(args.vocab)["vocab"]
    # Sets give O(1) membership tests in the per-row loop below.
    vocab_sets = {c: set(vocab[c].keys()) for c in disc_cols}

    cont_stats = init_stats(cont_cols)
    disc_invalid = {c: 0 for c in disc_cols}
    rows = 0

    with open_csv(args.generated) as f:
        reader = csv.DictReader(f)
        for row in reader:
            rows += 1
            # pop with a default is already safe; no membership check needed.
            row.pop(time_col, None)
            for c in cont_cols:
                try:
                    v = float(row[c])
                except (KeyError, TypeError, ValueError):
                    # Missing or unparsable values are scored as 0.0.
                    v = 0.0
                update_stats(cont_stats, c, v)
            for c in disc_cols:
                # .get() counts a missing column as invalid rather than crashing.
                if row.get(c) not in vocab_sets[c]:
                    disc_invalid[c] += 1

    cont_summary = finalize_stats(cont_stats)
    cont_err = {}
    for c in cont_cols:
        ref_mean = float(stats_ref[c])
        ref_std = float(std_ref[c])
        if ref_std == 0:
            ref_std = 1.0  # avoid degenerate zero-std reference
        gen = cont_summary[c]
        cont_err[c] = {
            "mean_abs_err": abs(gen["mean"] - ref_mean),
            "std_abs_err": abs(gen["std"] - ref_std),
        }

    report = {
        "rows": rows,
        "continuous_summary": cont_summary,
        "continuous_error": cont_err,
        "discrete_invalid_counts": disc_invalid,
    }

    # json.dump defaults to ensure_ascii=True, so an ASCII handle is safe here.
    with open(args.out, "w", encoding="ascii") as f:
        json.dump(report, f, indent=2)

    print("eval_report", args.out)


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user