Files
mask-ddpm/example/summary_metrics.py
2026-01-26 19:00:16 +08:00

41 lines
1.2 KiB
Python

#!/usr/bin/env python3
"""Print average metrics from eval.json for quick tracking."""
import json
from datetime import datetime, timezone
from pathlib import Path
def mean(values):
    """Return the arithmetic mean of *values*, or None if it is empty.

    *values* must support truthiness and ``len()`` (e.g. a list or tuple);
    an empty sequence yields ``None`` rather than raising ZeroDivisionError.
    """
    if not values:
        return None
    total = sum(values)
    return total / len(values)
def main():
    """Print average evaluation metrics and append them to a CSV history.

    Reads ``results/eval.json`` located next to this script, averages the
    values of the ``continuous_ks``, ``discrete_jsd`` and
    ``continuous_lag1_diff`` mappings, prints each average, and appends one
    row to ``results/metrics_history.csv``. The CSV header is written first
    when that file is missing or empty.

    Raises:
        SystemExit: if ``results/eval.json`` does not exist.
    """
    base_dir = Path(__file__).resolve().parent
    results_dir = base_dir / "results"
    eval_path = results_dir / "eval.json"
    if not eval_path.exists():
        raise SystemExit(f"missing eval.json: {eval_path}")
    obj = json.loads(eval_path.read_text(encoding="utf-8"))

    def values_of(key):
        # A missing key and an explicit JSON null are both treated as
        # "no values" instead of crashing on None.values().
        return list((obj.get(key) or {}).values())

    avg_ks = mean(values_of("continuous_ks"))
    avg_jsd = mean(values_of("discrete_jsd"))
    avg_lag1 = mean(values_of("continuous_lag1_diff"))
    print("avg_ks", avg_ks)
    print("avg_jsd", avg_jsd)
    print("avg_lag1_diff", avg_lag1)

    history_path = results_dir / "metrics_history.csv"
    # Also (re)write the header when the file exists but is empty,
    # e.g. after being created with `touch`.
    needs_header = not history_path.exists() or history_path.stat().st_size == 0
    # datetime.utcnow() is deprecated (3.12+). Use an aware UTC "now", then
    # drop tzinfo so the ISO string keeps the existing naive format.
    timestamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    with history_path.open("a", encoding="utf-8") as f:
        if needs_header:
            f.write("timestamp,avg_ks,avg_jsd,avg_lag1_diff\n")
        # A None average is written as the literal "None", as before.
        f.write(f"{timestamp},{avg_ks},{avg_jsd},{avg_lag1}\n")


if __name__ == "__main__":
    main()