Files
mask-ddpm/example/summary_metrics.py
Mingzhe Yang 10c0721ee1 update
2026-02-04 03:53:17 +08:00

116 lines
3.6 KiB
Python

#!/usr/bin/env python3
"""Print average metrics from eval.json and append to a history CSV."""
import argparse
import csv
import json
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
def mean(values):
    """Return the arithmetic mean of ``values``, or None when it is empty.

    ``values`` must support truthiness and ``len()`` (e.g. a list). An empty
    sequence yields None instead of raising ZeroDivisionError.
    """
    if not values:
        return None
    total = sum(values)
    return total / len(values)
def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    By default both the evaluation file (``eval.json``) and the history CSV
    (``metrics_history.csv``) live in a ``results`` directory next to this
    script. ``--seed`` defaults to -1; ``main()`` treats a negative seed as
    "not given".
    """
    results_dir = Path(__file__).resolve().parent / "results"
    cli = argparse.ArgumentParser(description="Summarize eval.json into a history CSV.")
    cli.add_argument("--eval", dest="eval_path", default=str(results_dir / "eval.json"))
    cli.add_argument("--history", default=str(results_dir / "metrics_history.csv"))
    for flag in ("--run-name", "--config"):
        cli.add_argument(flag, default="")
    cli.add_argument("--seed", type=int, default=-1)
    return cli.parse_args()
def read_last_row(history_path: Path) -> Optional[dict]:
    """Return the most recent data row of the history CSV, with metric columns parsed.

    Args:
        history_path: History CSV (a header row followed by one row per run).

    Returns:
        The last data row as a dict keyed by the CSV header, or None when the file
        is missing or has no data rows. Each of ``avg_ks``, ``avg_jsd`` and
        ``avg_lag1_diff`` present in the row is converted to ``float``. If the
        cell is blank or not a number, it is set to ``None``. Callers can
        therefore test ``is None`` and never see a leftover string.
    """
    if not history_path.exists():
        return None
    last = None
    with history_path.open("r", encoding="utf-8", newline="") as f:
        # Stream the rows and keep only the final one instead of loading them all.
        for row in csv.DictReader(f):
            last = row
    if last is None:
        return None
    for key in ("avg_ks", "avg_jsd", "avg_lag1_diff"):
        if key not in last:
            continue
        value = last[key]
        # A None average is written by csv.DictWriter as an empty cell; map it back to None.
        if value is None or value == "":
            last[key] = None
            continue
        try:
            last[key] = float(value)
        except (TypeError, ValueError):
            last[key] = None
    return last
def ensure_header(history_path: Path, fieldnames):
    """Create the history CSV with a header row if it does not have one yet.

    A missing file is created, along with any missing parent directories. An
    existing zero-byte file also gets a header, because appending data rows to
    it would otherwise leave the CSV without one. Non-empty files are left
    untouched.

    Args:
        history_path: Path of the history CSV.
        fieldnames: Column names to write as the header row.
    """
    if history_path.exists() and history_path.stat().st_size > 0:
        return
    history_path.parent.mkdir(parents=True, exist_ok=True)
    with history_path.open("w", encoding="utf-8", newline="") as f:
        csv.DictWriter(f, fieldnames=fieldnames).writeheader()
# History CSV column layouts; the extended one is used when any run metadata is supplied.
_BASIC_FIELDNAMES = ("timestamp", "avg_ks", "avg_jsd", "avg_lag1_diff")
_EXTENDED_FIELDNAMES = ("timestamp", "run_name", "config", "seed", "avg_ks", "avg_jsd", "avg_lag1_diff")


def _section_mean(obj: dict, key: str) -> Optional[float]:
    """Return the mean of the values of the mapping ``obj[key]``, or None if absent or empty."""
    # `or {}` also covers an explicit JSON null for the section.
    section = obj.get(key) or {}
    return mean(list(section.values()))


def _read_header(history_path: Path) -> Optional[list]:
    """Return the first CSV row of ``history_path``, or None if the file is missing or zero bytes."""
    if not history_path.exists() or history_path.stat().st_size == 0:
        return None
    with history_path.open("r", encoding="utf-8", newline="") as f:
        return next(csv.reader(f), [])


def _prepare_history(history_path: Path, wanted) -> list:
    """Make sure the history CSV has a header containing every column in ``wanted``.

    Returns the header that new rows must be written against.

    - Missing or empty file: a fresh header ``wanted`` is written.
    - Existing header that already has every wanted column: it is reused
      unchanged, keeping its column order, so appended rows line up with
      earlier ones.
    - Existing header missing some wanted columns: the file is rewritten with
      those columns appended at the end, and earlier rows get blank values
      for them.
    """
    header = _read_header(history_path)
    if header is None:
        history_path.parent.mkdir(parents=True, exist_ok=True)
        with history_path.open("w", encoding="utf-8", newline="") as f:
            csv.DictWriter(f, fieldnames=wanted).writeheader()
        return list(wanted)
    if not header:
        # Blank first line: no usable header to align with, so append with the requested columns.
        return list(wanted)
    missing = [name for name in wanted if name not in header]
    if not missing:
        return header
    upgraded = header + missing
    with history_path.open("r", encoding="utf-8", newline="") as f:
        rows = list(csv.DictReader(f))
    # Write to a sibling temp file and swap it in, so a crash cannot leave a half-written history.
    tmp_path = history_path.with_name(history_path.name + ".tmp")
    with tmp_path.open("w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=upgraded, extrasaction="ignore")
        writer.writeheader()
        writer.writerows(rows)
    tmp_path.replace(history_path)
    return upgraded


def _print_delta(label: str, current, previous) -> None:
    """Print ``current - previous`` under ``label`` when both are numbers; otherwise print nothing."""
    if isinstance(current, (int, float)) and isinstance(previous, (int, float)):
        print(label, current - previous)


def main():
    """Average the metrics in eval.json and record them in the history CSV.

    The steps are:

    1. Load eval.json.
    2. Average the values of the ``continuous_ks``, ``discrete_jsd`` and
       ``continuous_lag1_diff`` mappings.
    3. Write the averages back into eval.json as ``avg_ks``, ``avg_jsd`` and
       ``avg_lag1_diff``.
    4. Append one timestamped row to the history CSV.
    5. Print the averages and their change from the previous history row.

    Raises:
        SystemExit: if the eval.json path does not exist.
    """
    args = parse_args()
    eval_path = Path(args.eval_path)
    if not eval_path.exists():
        raise SystemExit(f"missing eval.json: {eval_path}")
    history_path = Path(args.history)

    obj = json.loads(eval_path.read_text(encoding="utf-8"))
    avg_ks = _section_mean(obj, "continuous_ks")
    avg_jsd = _section_mean(obj, "discrete_jsd")
    avg_lag1 = _section_mean(obj, "continuous_lag1_diff")
    obj["avg_ks"] = avg_ks
    obj["avg_jsd"] = avg_jsd
    obj["avg_lag1_diff"] = avg_lag1
    eval_path.write_text(json.dumps(obj, indent=2), encoding="utf-8")

    # Read the previous row before appending, so deltas compare against the last run.
    prev = read_last_row(history_path)

    extended = any([args.run_name, args.config, args.seed >= 0])
    wanted = _EXTENDED_FIELDNAMES if extended else _BASIC_FIELDNAMES
    # Align with the file's existing header (migrating it if needed) so columns never shift.
    fieldnames = _prepare_history(history_path, wanted)

    row = {
        # Naive ISO-8601 UTC string, identical in format to the deprecated datetime.utcnow().
        "timestamp": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
        "avg_ks": avg_ks,
        "avg_jsd": avg_jsd,
        "avg_lag1_diff": avg_lag1,
    }
    if extended:
        row["run_name"] = args.run_name
        row["config"] = args.config
        # A negative seed means "not given"; store a blank cell rather than -1.
        row["seed"] = args.seed if args.seed >= 0 else ""
    with history_path.open("a", encoding="utf-8", newline="") as f:
        # Columns absent from this row (e.g. run_name in a basic run) are written blank.
        csv.DictWriter(f, fieldnames=fieldnames).writerow(row)

    print("avg_ks", avg_ks)
    print("avg_jsd", avg_jsd)
    print("avg_lag1_diff", avg_lag1)
    if prev is not None:
        # Skip a delta when either side is missing or blank (e.g. an empty metric section).
        _print_delta("delta_avg_ks", avg_ks, prev.get("avg_ks"))
        _print_delta("delta_avg_jsd", avg_jsd, prev.get("avg_jsd"))
        _print_delta("delta_avg_lag1_diff", avg_lag1, prev.get("avg_lag1_diff"))
# Entry point: run the summary only when executed as a script, not on import.
if __name__ == "__main__":
    main()