Files
mask-ddpm/example/plot_loss.py
2026-01-25 18:13:37 +08:00

69 lines
1.8 KiB
Python

#!/usr/bin/env python3
"""Plot training loss curves from train_log.csv."""
import argparse
import csv
from pathlib import Path
import matplotlib.pyplot as plt
def parse_args():
    """Define the command-line options and parse them from ``sys.argv``.

    Both options default to files inside the ``results`` directory located
    next to this script, so running it with no arguments plots that log.

    Returns:
        argparse.Namespace with string attributes ``log`` (input CSV path)
        and ``out`` (output PNG path).
    """
    results_dir = Path(__file__).resolve().parent / "results"
    cli = argparse.ArgumentParser(
        description="Plot loss curves from train_log.csv"
    )
    # (flag, default file name inside results/, help text)
    options = (
        ("--log", "train_log.csv", "Path to train_log.csv"),
        ("--out", "train_loss.png", "Output PNG path"),
    )
    for flag, filename, help_text in options:
        cli.add_argument(
            flag,
            default=str(results_dir / filename),
            help=help_text,
        )
    return cli.parse_args()
# Columns every log must provide; "loss_quantile" is optional.
_REQUIRED_COLUMNS = ("step", "loss", "loss_cont", "loss_disc")

# (CSV column, legend label) in plotting order.
_CURVES = (
    ("loss", "total"),
    ("loss_cont", "continuous"),
    ("loss_disc", "discrete"),
    ("loss_quantile", "quantile"),
)


def _parse_float(value, column, line_num):
    """Convert one CSV cell to float, mapping blank/missing cells to NaN.

    NaN points are drawn as gaps by matplotlib rather than aborting the plot.

    Raises:
        SystemExit: if the cell holds non-numeric text.
    """
    # DictReader yields None for cells missing from a short row.
    if value is None or not value.strip():
        return float("nan")
    try:
        return float(value)
    except ValueError:
        raise SystemExit(
            f"bad value {value!r} in column {column!r} at line {line_num}"
        ) from None


def _read_log(log_path):
    """Read the training log CSV into per-column lists.

    Returns:
        dict mapping "step" to a list of ints and each loss column to a list
        of floats of the same length. "loss_quantile" is present only when
        the CSV header contains it. An empty file yields empty lists.

    Raises:
        SystemExit: if required columns are missing or a value is malformed.
    """
    # utf-8-sig also accepts files with a BOM (e.g. re-saved by Excel),
    # which would otherwise turn the "step" header into "\ufeffstep".
    with log_path.open("r", encoding="utf-8-sig", newline="") as f:
        reader = csv.DictReader(f)
        fieldnames = reader.fieldnames or []
        missing = [c for c in _REQUIRED_COLUMNS if c not in fieldnames]
        if fieldnames and missing:
            raise SystemExit(
                f"log file {log_path} is missing columns: {', '.join(missing)}"
            )
        loss_columns = [c for c, _ in _CURVES if c in fieldnames]
        series = {"step": [], **{c: [] for c in loss_columns}}
        for row in reader:
            raw_step = row["step"]
            try:
                series["step"].append(int(raw_step))
            except (TypeError, ValueError):
                raise SystemExit(
                    f"bad value {raw_step!r} in column 'step' "
                    f"at line {reader.line_num}"
                ) from None
            # Every column gets exactly one entry per row, so all lists stay
            # the same length as "step" (required by plot()).
            for column in loss_columns:
                series[column].append(
                    _parse_float(row[column], column, reader.line_num)
                )
    return series


def main():
    """Plot loss curves from the training log CSV and save them as a PNG.

    Reads the ``--log`` CSV named on the command line. It draws total,
    continuous and discrete losses against the step, plus the quantile loss
    when the log has a ``loss_quantile`` column. The figure is written to
    ``--out``, and that file's parent directory is created if needed.

    Raises:
        SystemExit: if the log is missing, has no data rows, lacks required
            columns, or contains a malformed value.
    """
    args = parse_args()
    log_path = Path(args.log)
    if not log_path.is_file():
        raise SystemExit(f"missing log file: {log_path}")

    series = _read_log(log_path)
    steps = series["step"]
    if not steps:
        raise SystemExit(f"no rows in log file: {log_path}")

    fig, ax = plt.subplots(figsize=(8, 5))
    for column, label in _CURVES:
        if column in series:
            ax.plot(steps, series[column], label=label)
    ax.set_xlabel("step")
    ax.set_ylabel("loss")
    ax.set_title("Training Loss")
    ax.legend()
    fig.tight_layout()

    out_path = Path(args.out)
    # A custom --out may point into a directory that does not exist yet.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=150)
    plt.close(fig)
    print("saved", args.out)
if __name__ == "__main__":
    # Plot only when run as a script; importing the module has no effects.
    main()