285 lines
9.6 KiB
Python
285 lines
9.6 KiB
Python
#!/usr/bin/env python3
|
|
"""Evaluate generated samples against simple stats and vocab validity."""
|
|
|
|
import argparse
|
|
import csv
|
|
import gzip
|
|
import json
|
|
import math
|
|
from pathlib import Path
|
|
from typing import Dict, Tuple, List, Optional
|
|
|
|
|
|
def load_json(path: str) -> Dict:
    """Read a UTF-8 encoded JSON file and return the parsed object."""
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)
|
|
|
|
|
|
def open_csv(path: str):
    """Open a CSV file for text reading, transparently decompressing .gz files."""
    if not path.endswith(".gz"):
        return open(path, "r", newline="")
    return gzip.open(path, "rt", newline="")
|
|
|
|
|
|
def parse_args():
    """Build the CLI parser for the evaluator and return the parsed namespace.

    All path defaults are anchored at this script's directory.
    """
    here = Path(__file__).resolve().parent
    results = here / "results"
    parser = argparse.ArgumentParser(description="Evaluate generated CSV samples.")
    parser.add_argument("--generated", default=str(results / "generated.csv"))
    parser.add_argument("--split", default=str(here / "feature_split.json"))
    parser.add_argument("--stats", default=str(results / "cont_stats.json"))
    parser.add_argument("--vocab", default=str(results / "disc_vocab.json"))
    parser.add_argument("--out", default=str(results / "eval.json"))
    parser.add_argument(
        "--reference", default="", help="Optional reference CSV (train) for richer metrics"
    )
    parser.add_argument(
        "--max-rows", type=int, default=20000, help="Max rows to load for reference metrics"
    )
    return parser.parse_args()
|
|
|
|
|
|
def init_stats(cols):
    """Create an empty Welford accumulator (count/mean/m2) per column name."""
    return {name: {"count": 0, "mean": 0.0, "m2": 0.0} for name in cols}
|
|
|
|
|
|
def update_stats(stats, col, value):
    """Fold one observation into the running Welford accumulator for *col*.

    Updates count, mean, and the sum of squared deviations (m2) in place.
    """
    acc = stats[col]
    acc["count"] += 1
    step = value - acc["mean"]
    acc["mean"] += step / acc["count"]
    # m2 accumulates using the pre- and post-update deviations (Welford).
    acc["m2"] += step * (value - acc["mean"])
|
|
|
|
|
|
def finalize_stats(stats):
    """Convert Welford accumulators into ``{col: {"mean", "std"}}``.

    Uses the sample variance (n - 1 denominator); fewer than two samples
    yields a standard deviation of zero.
    """
    summary = {}
    for name, acc in stats.items():
        variance = acc["m2"] / (acc["count"] - 1) if acc["count"] > 1 else 0.0
        summary[name] = {"mean": acc["mean"], "std": variance ** 0.5}
    return summary
|
|
|
|
|
|
def js_divergence(p, q, eps: float = 1e-12) -> float:
    """Jensen-Shannon divergence (log base 2) between two distributions.

    Probabilities are clamped to *eps* to avoid log(0); inputs are assumed
    to be (approximately) normalized.
    """
    left = [max(v, eps) for v in p]
    right = [max(v, eps) for v in q]
    mid = [0.5 * (a + b) for a, b in zip(left, right)]

    def _kl(num, den):
        # Kullback-Leibler divergence in bits.
        return sum(a * math.log(a / b, 2) for a, b in zip(num, den))

    return 0.5 * _kl(left, mid) + 0.5 * _kl(right, mid)
|
|
|
|
|
|
def ks_statistic(x: List[float], y: List[float]) -> float:
    """Two-sample Kolmogorov-Smirnov statistic: the largest gap between
    the empirical CDFs of *x* and *y*. Returns 0.0 if either sample is empty.
    """
    if not x or not y:
        return 0.0
    xs = sorted(x)
    ys = sorted(y)
    n, m = len(xs), len(ys)
    i = j = 0
    best = 0.0
    # Evaluate both empirical CDFs at every distinct sample value so that
    # tied values are handled correctly.
    for v in sorted(set(xs).union(ys)):
        while i < n and xs[i] <= v:
            i += 1
        while j < m and ys[j] <= v:
            j += 1
        gap = abs(i / n - j / m)
        if gap > best:
            best = gap
    return best
|
|
|
|
|
|
def lag1_corr(values: List[float]) -> float:
    """Pearson correlation between the series and itself shifted by one step.

    Returns 0.0 for fewer than three points or for a degenerate (constant)
    series where a variance term vanishes.
    """
    if len(values) < 3:
        return 0.0
    lag = values[:-1]
    lead = values[1:]
    mu_lag = sum(lag) / len(lag)
    mu_lead = sum(lead) / len(lead)
    cov = sum((a - mu_lag) * (b - mu_lead) for a, b in zip(lag, lead))
    var_lag = sum((a - mu_lag) ** 2 for a in lag)
    var_lead = sum((b - mu_lead) ** 2 for b in lead)
    if var_lag <= 0 or var_lead <= 0:
        return 0.0
    return cov / math.sqrt(var_lag * var_lead)
|
|
|
|
|
|
def resolve_reference_paths(path: str) -> List[str]:
    """Expand *path* into a list of concrete file paths.

    An empty string yields []; a path containing glob wildcards is expanded
    (sorted) against its parent directory; anything else is returned as-is.
    """
    if not path:
        return []
    has_wildcard = any(ch in path for ch in "*?[")
    if not has_wildcard:
        return [str(path)]
    parent = Path(path).parent.resolve()
    pattern = Path(path).name
    return [str(match) for match in sorted(parent.glob(pattern))]
|
|
|
|
|
|
def _resolve_file(base_dir: Path, p: str) -> str:
    """Resolve *p* to an absolute path: absolute as-is, else CWD-relative if it
    exists, else relative to *base_dir* (even if that target does not exist yet)."""
    path = Path(p)
    if path.is_absolute():
        return str(path)
    if path.exists():
        return str(path.resolve())
    candidate = base_dir / path
    if candidate.exists():
        return str(candidate.resolve())
    # Fall back to the base_dir-relative location so output paths still resolve.
    return str((base_dir / path).resolve())


def _scan_generated(path: str, time_col: str, cont_cols, disc_cols, vocab_sets):
    """First pass over the generated CSV.

    Returns (row_count, welford_stats, invalid_token_counts). Unparseable
    continuous values are treated as 0.0; discrete tokens outside the vocab
    increment the per-column invalid counter.
    """
    cont_stats = init_stats(cont_cols)
    disc_invalid = {c: 0 for c in disc_cols}
    rows = 0
    with open_csv(path) as f:
        for row in csv.DictReader(f):
            rows += 1
            row.pop(time_col, None)  # time column is excluded from all metrics
            for c in cont_cols:
                try:
                    v = float(row[c])
                except Exception:
                    v = 0.0
                update_stats(cont_stats, c, v)
            for c in disc_cols:
                if row[c] not in vocab_sets[c]:
                    disc_invalid[c] += 1
    return rows, cont_stats, disc_invalid


def _accumulate_rows(reader, time_col, cont_cols, disc_cols, cont_acc, disc_acc,
                     limit=None, loaded=0):
    """Append continuous values and count discrete tokens from *reader*.

    Stops once *loaded* reaches *limit* (if limit is truthy). Returns the
    updated *loaded* count so a caller can resume across multiple files.
    """
    for row in reader:
        row.pop(time_col, None)
        for c in cont_cols:
            try:
                cont_acc[c].append(float(row[c]))
            except Exception:
                cont_acc[c].append(0.0)
        for c in disc_cols:
            tok = row[c]
            disc_acc[c][tok] = disc_acc[c].get(tok, 0) + 1
        loaded += 1
        if limit and loaded >= limit:
            break
    return loaded


def _quantile(arr: List[float], q: float) -> float:
    """Quantile of a pre-sorted list via truncated index; 0.0 for empty input."""
    if not arr:
        return 0.0
    return arr[int(q * (len(arr) - 1))]


def main():
    """Evaluate a generated CSV against saved stats/vocab and write a JSON report.

    Always computes per-column mean/std errors and invalid-token counts; when a
    reference CSV (or glob) is supplied, also computes KS statistics, quantile
    differences, lag-1 autocorrelation differences, and discrete JS divergence.
    """
    args = parse_args()
    base_dir = Path(__file__).resolve().parent

    args.generated = _resolve_file(base_dir, args.generated)
    args.split = _resolve_file(base_dir, args.split)
    args.stats = _resolve_file(base_dir, args.stats)
    args.vocab = _resolve_file(base_dir, args.vocab)
    args.out = _resolve_file(base_dir, args.out)
    if args.reference and not Path(args.reference).is_absolute():
        if any(ch in args.reference for ch in ["*", "?", "["]):
            # Keep glob patterns unresolved so the wildcard survives.
            args.reference = str(base_dir / args.reference)
        else:
            args.reference = str((base_dir / args.reference).resolve())
    ref_paths = resolve_reference_paths(args.reference)

    split = load_json(args.split)
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    # attack* label columns are excluded from discrete validity checks.
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]

    stats_json = load_json(args.stats)
    # Prefer raw-space stats; fall back to legacy key names.
    stats_ref = stats_json.get("raw_mean", stats_json.get("mean"))
    std_ref = stats_json.get("raw_std", stats_json.get("std"))
    vocab = load_json(args.vocab)["vocab"]
    vocab_sets = {c: set(vocab[c].keys()) for c in disc_cols}

    rows, cont_stats, disc_invalid = _scan_generated(
        args.generated, time_col, cont_cols, disc_cols, vocab_sets
    )

    cont_summary = finalize_stats(cont_stats)
    cont_err = {}
    for c in cont_cols:
        ref_mean = float(stats_ref[c])
        # Guard against a zero reference std (matches historical behavior).
        ref_std = float(std_ref[c]) if float(std_ref[c]) != 0 else 1.0
        cont_err[c] = {
            "mean_abs_err": abs(cont_summary[c]["mean"] - ref_mean),
            "std_abs_err": abs(cont_summary[c]["std"] - ref_std),
        }

    report = {
        "rows": rows,
        "continuous_summary": cont_summary,
        "continuous_error": cont_err,
        "discrete_invalid_counts": disc_invalid,
    }

    # Optional richer metrics using reference data.
    if ref_paths:
        gen_cont = {c: [] for c in cont_cols}
        gen_disc = {c: {} for c in disc_cols}
        ref_cont = {c: [] for c in cont_cols}
        ref_disc = {c: {} for c in disc_cols}

        # Second pass over the generated file: full value lists (no row cap).
        with open_csv(args.generated) as f:
            _accumulate_rows(csv.DictReader(f), time_col, cont_cols, disc_cols,
                             gen_cont, gen_disc)

        # Reference files share a single --max-rows budget across all paths.
        loaded = 0
        for ref_path in ref_paths:
            with open_csv(ref_path) as f:
                loaded = _accumulate_rows(csv.DictReader(f), time_col, cont_cols,
                                          disc_cols, ref_cont, ref_disc,
                                          limit=args.max_rows, loaded=loaded)
            if args.max_rows and loaded >= args.max_rows:
                break

        # Continuous metrics: KS + quantile gaps + lag-1 autocorrelation gap.
        quantile_keys = [
            ("q05_diff", 0.05),
            ("q25_diff", 0.25),
            ("q50_diff", 0.5),
            ("q75_diff", 0.75),
            ("q95_diff", 0.95),
        ]
        cont_ks = {}
        cont_quant = {}
        cont_lag1 = {}
        for c in cont_cols:
            cont_ks[c] = ks_statistic(gen_cont[c], ref_cont[c])
            gen_sorted = sorted(gen_cont[c])
            ref_sorted = sorted(ref_cont[c])
            cont_quant[c] = {
                key: abs(_quantile(gen_sorted, q) - _quantile(ref_sorted, q))
                for key, q in quantile_keys
            }
            cont_lag1[c] = abs(lag1_corr(gen_cont[c]) - lag1_corr(ref_cont[c]))

        # Discrete metrics: Jensen-Shannon divergence over the shared vocab.
        disc_jsd = {}
        for c in disc_cols:
            vocab_vals = list(vocab_sets[c])
            gen_total = sum(gen_disc[c].values()) or 1
            ref_total = sum(ref_disc[c].values()) or 1
            p = [gen_disc[c].get(v, 0) / gen_total for v in vocab_vals]
            q = [ref_disc[c].get(v, 0) / ref_total for v in vocab_vals]
            disc_jsd[c] = js_divergence(p, q)

        report["continuous_ks"] = cont_ks
        report["continuous_quantile_diff"] = cont_quant
        report["continuous_lag1_diff"] = cont_lag1
        report["discrete_jsd"] = disc_jsd

    with open(args.out, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)

    print("eval_report", args.out)
|
|
|
|
|
|
# Script entry point: only run the evaluation when executed directly.
if __name__ == "__main__":
    main()
|