69 lines
1.8 KiB
Python
69 lines
1.8 KiB
Python
#!/usr/bin/env python3
|
|
"""Plot training loss curves from train_log.csv."""
|
|
|
|
import argparse
|
|
import csv
|
|
from pathlib import Path
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options for the loss-plotting script.

    Args:
        argv: Argument strings to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, so existing
            callers behave exactly as before; passing an explicit list lets
            the parser be exercised without touching the process argv.

    Returns:
        argparse.Namespace with string attributes ``log`` (input CSV path)
        and ``out`` (output image path). Both default to files inside a
        ``results`` directory located next to this script.
    """
    results_dir = Path(__file__).resolve().parent / "results"
    parser = argparse.ArgumentParser(description="Plot loss curves from train_log.csv")
    parser.add_argument(
        "--log",
        default=str(results_dir / "train_log.csv"),
        help="Path to train_log.csv",
    )
    parser.add_argument(
        "--out",
        default=str(results_dir / "train_loss.png"),
        help="Output PNG path",
    )
    return parser.parse_args(argv)
|
|
|
|
|
|
_REQUIRED_COLUMNS = ("step", "loss", "loss_cont", "loss_disc")


def _read_log(log_path):
    """Read the training-log CSV at *log_path* into plottable series.

    The header must contain every name in ``_REQUIRED_COLUMNS``. The
    ``loss_quantile`` column is optional, and its cells may be blank.

    Returns:
        ``(steps, losses, quant_steps, quant)``.

        - ``steps``: the ``step`` column as ints.
        - ``losses``: maps ``"loss"``, ``"loss_cont"`` and ``"loss_disc"``
          to float lists aligned with ``steps``.
        - ``quant_steps``/``quant``: only the rows whose ``loss_quantile``
          cell is non-blank, so the two stay aligned with each other even
          when that column is sparsely filled.

        An empty file (no header) yields all-empty lists.

    Raises:
        SystemExit: if a required column is absent from the header, or a
            required cell is missing or not numeric.
    """
    steps = []
    losses = {name: [] for name in _REQUIRED_COLUMNS[1:]}
    quant_steps = []
    quant = []
    # utf-8-sig reads plain UTF-8 unchanged but also strips a leading BOM,
    # which would otherwise be glued onto the first header name ("step").
    with log_path.open("r", encoding="utf-8-sig", newline="") as f:
        reader = csv.DictReader(f)
        header = reader.fieldnames
        if header is None:
            # Empty file: return nothing so the caller reports "no rows".
            return steps, losses, quant_steps, quant
        missing = [name for name in _REQUIRED_COLUMNS if name not in header]
        if missing:
            raise SystemExit(
                f"log file {log_path} is missing column(s): {', '.join(missing)}"
            )
        for row in reader:
            raw_quant = row.get("loss_quantile")
            has_quant = raw_quant is not None and raw_quant.strip() != ""
            try:
                step = int(row["step"])
                values = {name: float(row[name]) for name in losses}
                quant_value = float(raw_quant) if has_quant else None
            except (TypeError, ValueError) as err:
                # TypeError: DictReader fills the cells of short rows with None.
                raise SystemExit(
                    f"bad value on line {reader.line_num} of {log_path}: {err}"
                ) from err
            steps.append(step)
            for name, value in values.items():
                losses[name].append(value)
            if quant_value is not None:
                quant_steps.append(step)
                quant.append(quant_value)
    return steps, losses, quant_steps, quant


def _plot(steps, losses, quant_steps, quant, out_path):
    """Draw the loss curves and save the figure (150 dpi) to *out_path*.

    The parent directory of *out_path* is created if it does not exist. The
    figure is closed afterwards, even on error, so it does not linger in
    pyplot's figure registry.
    """
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        ax.plot(steps, losses["loss"], label="total")
        ax.plot(steps, losses["loss_cont"], label="continuous")
        ax.plot(steps, losses["loss_disc"], label="discrete")
        if quant:
            # Plotted against its own steps: blank cells were skipped.
            ax.plot(quant_steps, quant, label="quantile")
        ax.set_xlabel("step")
        ax.set_ylabel("loss")
        ax.set_title("Training Loss")
        ax.legend()
        fig.tight_layout()
        fig.savefig(out_path, dpi=150)
    finally:
        plt.close(fig)


def main():
    """Plot the losses from the ``--log`` CSV and save the figure to ``--out``.

    Raises:
        SystemExit: if the log file is missing, has no data rows, lacks a
            required column, or contains a non-numeric required value.
    """
    args = parse_args()
    log_path = Path(args.log)
    # is_file() rather than exists(): a directory would pass exists() and
    # then fail inside open() with a raw traceback.
    if not log_path.is_file():
        raise SystemExit(f"missing log file: {log_path}")
    steps, losses, quant_steps, quant = _read_log(log_path)
    if not steps:
        raise SystemExit(f"no rows in log file: {log_path}")
    _plot(steps, losses, quant_steps, quant, args.out)
    print("saved", args.out)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point; importing this module does not trigger plotting.
    main()
|