update ks

This commit is contained in:
2026-01-25 18:13:37 +08:00
parent b3c45010a4
commit bc838d7cd7
3 changed files with 47 additions and 6 deletions

View File

@@ -34,6 +34,7 @@ def main():
loss = []
loss_cont = []
loss_disc = []
loss_quant = []
with log_path.open("r", encoding="utf-8", newline="") as f:
reader = csv.DictReader(f)
@@ -42,6 +43,8 @@ def main():
loss.append(float(row["loss"]))
loss_cont.append(float(row["loss_cont"]))
loss_disc.append(float(row["loss_disc"]))
if "loss_quantile" in row:
loss_quant.append(float(row["loss_quantile"]))
if not steps:
raise SystemExit("no rows in log file: %s" % log_path)
@@ -50,6 +53,8 @@ def main():
plt.plot(steps, loss, label="total")
plt.plot(steps, loss_cont, label="continuous")
plt.plot(steps, loss_disc, label="discrete")
if loss_quant:
plt.plot(steps, loss_quant, label="quantile")
plt.xlabel("step")
plt.ylabel("loss")
plt.title("Training Loss")