update ks
This commit is contained in:
@@ -34,6 +34,7 @@ def main():
|
||||
loss = []
|
||||
loss_cont = []
|
||||
loss_disc = []
|
||||
loss_quant = []
|
||||
|
||||
with log_path.open("r", encoding="utf-8", newline="") as f:
|
||||
reader = csv.DictReader(f)
|
||||
@@ -42,6 +43,8 @@ def main():
|
||||
loss.append(float(row["loss"]))
|
||||
loss_cont.append(float(row["loss_cont"]))
|
||||
loss_disc.append(float(row["loss_disc"]))
|
||||
if "loss_quantile" in row:
|
||||
loss_quant.append(float(row["loss_quantile"]))
|
||||
|
||||
if not steps:
|
||||
raise SystemExit("no rows in log file: %s" % log_path)
|
||||
@@ -50,6 +53,8 @@ def main():
|
||||
plt.plot(steps, loss, label="total")
|
||||
plt.plot(steps, loss_cont, label="continuous")
|
||||
plt.plot(steps, loss_disc, label="discrete")
|
||||
if loss_quant:
|
||||
plt.plot(steps, loss_quant, label="quantile")
|
||||
plt.xlabel("step")
|
||||
plt.ylabel("loss")
|
||||
plt.title("Training Loss")
|
||||
|
||||
Reference in New Issue
Block a user