update
@@ -1,41 +1,62 @@
 #!/usr/bin/env python3
-"""Print average metrics from eval.json and compare with previous run."""
+"""Print average metrics from eval.json and append to a history CSV."""
 
+import argparse
+import csv
 import json
 from datetime import datetime
 from pathlib import Path
+from typing import Optional
 
 
 def mean(values):
     return sum(values) / len(values) if values else None
 
 
-def parse_last_row(history_path: Path):
+def parse_args():
+    base_dir = Path(__file__).resolve().parent
+    parser = argparse.ArgumentParser(description="Summarize eval.json into a history CSV.")
+    parser.add_argument("--eval", dest="eval_path", default=str(base_dir / "results" / "eval.json"))
+    parser.add_argument("--history", default=str(base_dir / "results" / "metrics_history.csv"))
+    parser.add_argument("--run-name", default="")
+    parser.add_argument("--config", default="")
+    parser.add_argument("--seed", type=int, default=-1)
+    return parser.parse_args()
+
+
+def read_last_row(history_path: Path) -> Optional[dict]:
     if not history_path.exists():
         return None
-    rows = history_path.read_text(encoding="utf-8").strip().splitlines()
-    if len(rows) < 2:
+    with history_path.open("r", encoding="utf-8", newline="") as f:
+        reader = csv.DictReader(f)
+        rows = list(reader)
+    if not rows:
         return None
-    for line in reversed(rows[1:]):
-        parts = line.split(",")
-        if len(parts) < 4:
-            continue
-        try:
-            return {
-                "avg_ks": float(parts[1]),
-                "avg_jsd": float(parts[2]),
-                "avg_lag1_diff": float(parts[3]),
-            }
-        except Exception:
-            continue
-    return None
+    last = rows[-1]
+    for key in ["avg_ks", "avg_jsd", "avg_lag1_diff"]:
+        if key in last and last[key] not in [None, ""]:
+            try:
+                last[key] = float(last[key])
+            except Exception:
+                last[key] = None
+    return last
+
+
+def ensure_header(history_path: Path, fieldnames):
+    if history_path.exists():
+        return
+    history_path.parent.mkdir(parents=True, exist_ok=True)
+    with history_path.open("w", encoding="utf-8", newline="") as f:
+        writer = csv.DictWriter(f, fieldnames=fieldnames)
+        writer.writeheader()
 
 
 def main():
-    base_dir = Path(__file__).resolve().parent
-    eval_path = base_dir / "results" / "eval.json"
+    args = parse_args()
+    eval_path = Path(args.eval_path)
     if not eval_path.exists():
         raise SystemExit(f"missing eval.json: {eval_path}")
+    history_path = Path(args.history)
 
     obj = json.loads(eval_path.read_text(encoding="utf-8"))
     ks = list(obj.get("continuous_ks", {}).values())
@@ -46,22 +67,48 @@ def main():
     avg_jsd = mean(jsd)
     avg_lag1 = mean(lag)
 
-    history_path = base_dir / "results" / "metrics_history.csv"
-    prev = parse_last_row(history_path)
     obj["avg_ks"] = avg_ks
     obj["avg_jsd"] = avg_jsd
     obj["avg_lag1_diff"] = avg_lag1
     eval_path.write_text(json.dumps(obj, indent=2), encoding="utf-8")
 
-    if not history_path.exists():
-        history_path.write_text("timestamp,avg_ks,avg_jsd,avg_lag1_diff\n", encoding="utf-8")
-    with history_path.open("a", encoding="utf-8") as f:
-        f.write(f"{datetime.utcnow().isoformat()},{avg_ks},{avg_jsd},{avg_lag1}\n")
+    prev = read_last_row(history_path)
+
+    fieldnames = ["timestamp", "avg_ks", "avg_jsd", "avg_lag1_diff"]
+    extended = any([args.run_name, args.config, args.seed >= 0])
+    if extended:
+        fieldnames = ["timestamp", "run_name", "config", "seed", "avg_ks", "avg_jsd", "avg_lag1_diff"]
+    ensure_header(history_path, fieldnames)
+
+    row = {
+        "timestamp": datetime.utcnow().isoformat(),
+        "avg_ks": avg_ks,
+        "avg_jsd": avg_jsd,
+        "avg_lag1_diff": avg_lag1,
+    }
+    if extended:
+        row["run_name"] = args.run_name
+        row["config"] = args.config
+        row["seed"] = args.seed if args.seed >= 0 else ""
+
+    with history_path.open("a", encoding="utf-8", newline="") as f:
+        writer = csv.DictWriter(f, fieldnames=fieldnames)
+        writer.writerow(row)
 
     print("avg_ks", avg_ks)
     print("avg_jsd", avg_jsd)
     print("avg_lag1_diff", avg_lag1)
 
     if prev is not None:
-        print("delta_avg_ks", avg_ks - prev["avg_ks"])
-        print("delta_avg_jsd", avg_jsd - prev["avg_jsd"])
-        print("delta_avg_lag1_diff", avg_lag1 - prev["avg_lag1_diff"])
+        pks = prev.get("avg_ks")
+        pjsd = prev.get("avg_jsd")
+        plag = prev.get("avg_lag1_diff")
+        if pks is not None:
+            print("delta_avg_ks", avg_ks - pks)
+        if pjsd is not None:
+            print("delta_avg_jsd", avg_jsd - pjsd)
+        if plag is not None:
+            print("delta_avg_lag1_diff", avg_lag1 - plag)
 
 
 if __name__ == "__main__":
     main()
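For context, the updated script is driven entirely by the flags defined in parse_args. A typical invocation might look like the line below; the script filename is a placeholder (the diff does not show it) and the flag values are illustrative only:

    python3 summarize_eval.py --eval results/eval.json --history results/metrics_history.csv --run-name baseline --seed 0

Omitting --run-name, --config, and --seed keeps the original four-column history rows; supplying any of them switches to the extended seven-column row layout, with the matching header written by ensure_header only when the history file does not exist yet.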