#!/usr/bin/env python3
"""Plot training loss curves from train_log.csv."""

import argparse
import csv
import math
from pathlib import Path

import matplotlib.pyplot as plt

# Columns every log row must provide.
REQUIRED_COLUMNS = ("step", "loss", "loss_cont", "loss_disc")
# Optional column; plotted only when present in the CSV header and non-blank.
QUANTILE_COLUMN = "loss_quantile"


def parse_args():
    """Parse command-line options.

    Both paths default to files in the ``results/`` directory next to this
    script.
    """
    parser = argparse.ArgumentParser(description="Plot loss curves from train_log.csv")
    base_dir = Path(__file__).resolve().parent
    parser.add_argument(
        "--log",
        default=str(base_dir / "results" / "train_log.csv"),
        help="Path to train_log.csv",
    )
    parser.add_argument(
        "--out",
        default=str(base_dir / "results" / "train_loss.png"),
        help="Output PNG path",
    )
    return parser.parse_args()


def _parse_float(value):
    """Convert an optional CSV cell to float, mapping blank/missing cells to NaN.

    ``csv.DictReader`` fills cells missing from a short row with ``None``,
    and a column that was only logged for some steps may hold ``""``.
    Using NaN keeps the series the same length as ``steps``; matplotlib
    does not draw NaN points, so they show up as gaps.
    """
    if value is None or not value.strip():
        return float("nan")
    return float(value)


def _read_log(log_path):
    """Read the training log CSV into parallel lists.

    Returns:
        ``(steps, loss, loss_cont, loss_disc, loss_quant)``. ``loss_quant`` is
        empty when the CSV has no ``loss_quantile`` column; otherwise it has
        one entry per step (NaN where the cell is blank).

    Raises:
        SystemExit: if a required column is missing from the header, or a
            required value cannot be parsed.
    """
    steps, loss, loss_cont, loss_disc, loss_quant = [], [], [], [], []
    with log_path.open("r", encoding="utf-8", newline="") as f:
        reader = csv.DictReader(f)
        # fieldnames is None for a completely empty file; the caller then
        # reports "no rows" because nothing is read below.
        fieldnames = reader.fieldnames or []
        missing = [name for name in REQUIRED_COLUMNS if name not in fieldnames]
        if fieldnames and missing:
            raise SystemExit(
                f"missing column(s) {', '.join(missing)} in log file: {log_path}"
            )
        has_quantile = QUANTILE_COLUMN in fieldnames

        for row in reader:
            try:
                steps.append(int(row["step"]))
                loss.append(float(row["loss"]))
                loss_cont.append(float(row["loss_cont"]))
                loss_disc.append(float(row["loss_disc"]))
            except (TypeError, ValueError) as exc:
                # TypeError: cell is None because the row is shorter than the header.
                raise SystemExit(
                    f"bad value near line {reader.line_num} of {log_path}: {exc}"
                ) from exc
            if has_quantile:
                loss_quant.append(_parse_float(row[QUANTILE_COLUMN]))
    return steps, loss, loss_cont, loss_disc, loss_quant


def main():
    """Read the training log and save a PNG with the loss curves."""
    args = parse_args()
    log_path = Path(args.log)
    if not log_path.exists():
        raise SystemExit(f"missing log file: {log_path}")

    steps, loss, loss_cont, loss_disc, loss_quant = _read_log(log_path)
    if not steps:
        raise SystemExit(f"no rows in log file: {log_path}")

    out_path = Path(args.out)
    # savefig does not create directories, so make sure the target exists.
    out_path.parent.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(steps, loss, label="total")
    ax.plot(steps, loss_cont, label="continuous")
    ax.plot(steps, loss_disc, label="discrete")
    # Skip the quantile curve when the column is absent or entirely blank,
    # so the legend does not list an empty line.
    if any(not math.isnan(v) for v in loss_quant):
        ax.plot(steps, loss_quant, label="quantile")
    ax.set_xlabel("step")
    ax.set_ylabel("loss")
    ax.set_title("Training Loss")
    ax.legend()
    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    plt.close(fig)
    print("saved", args.out)


if __name__ == "__main__":
    main()