diff --git a/example/config.json b/example/config.json index f812197..90ff630 100644 --- a/example/config.json +++ b/example/config.json @@ -39,6 +39,9 @@ "cont_clamp_x0": 5.0, "quantile_loss_weight": 0.1, "quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95], + "quantile_loss_warmup_steps": 200, + "quantile_loss_clip": 6.0, + "quantile_loss_huber_delta": 1.0, "shuffle_buffer": 256, "sample_batch_size": 8, "sample_seq_len": 128 diff --git a/example/plot_loss.py b/example/plot_loss.py index d45a7ed..004a328 100644 --- a/example/plot_loss.py +++ b/example/plot_loss.py @@ -34,6 +34,7 @@ def main(): loss = [] loss_cont = [] loss_disc = [] + loss_quant = [] with log_path.open("r", encoding="utf-8", newline="") as f: reader = csv.DictReader(f) @@ -42,6 +43,8 @@ def main(): loss.append(float(row["loss"])) loss_cont.append(float(row["loss_cont"])) loss_disc.append(float(row["loss_disc"])) + if "loss_quantile" in row: + loss_quant.append(float(row["loss_quantile"])) if not steps: raise SystemExit("no rows in log file: %s" % log_path) @@ -50,6 +53,8 @@ def main(): plt.plot(steps, loss, label="total") plt.plot(steps, loss_cont, label="continuous") plt.plot(steps, loss_disc, label="discrete") + if loss_quant: + plt.plot(steps, loss_quant, label="quantile") plt.xlabel("step") plt.ylabel("loss") plt.title("Training Loss") diff --git a/example/train.py b/example/train.py index 4cee971..524464e 100755 --- a/example/train.py +++ b/example/train.py @@ -66,6 +66,9 @@ DEFAULTS = { "cont_clamp_x0": 0.0, "quantile_loss_weight": 0.0, "quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95], + "quantile_loss_warmup_steps": 200, + "quantile_loss_clip": 6.0, + "quantile_loss_huber_delta": 1.0, } @@ -205,8 +208,12 @@ def main(): os.makedirs(config["out_dir"], exist_ok=True) out_dir = safe_path(config["out_dir"]) log_path = os.path.join(out_dir, "train_log.csv") + use_quantile = float(config.get("quantile_loss_weight", 0.0)) > 0 with open(log_path, "w", encoding="utf-8") as f: - f.write("epoch,step,loss,loss_cont,loss_disc\n") + if use_quantile: + f.write("epoch,step,loss,loss_cont,loss_disc,loss_quantile\n") + else: + f.write("epoch,step,loss,loss_cont,loss_disc\n") with open(os.path.join(out_dir, "config_used.json"), "w", encoding="utf-8") as f: json.dump(config, f, indent=2) @@ -290,7 +297,11 @@ def main(): loss = lam * loss_cont + (1 - lam) * loss_disc q_weight = float(config.get("quantile_loss_weight", 0.0)) + quantile_loss = 0.0 if q_weight > 0: + warmup = int(config.get("quantile_loss_warmup_steps", 0)) + if warmup > 0: + q_weight = q_weight * min(1.0, (total_step + 1) / float(warmup)) q_points = config.get("quantile_points", [0.05, 0.25, 0.5, 0.75, 0.95]) q_tensor = torch.tensor(q_points, device=device, dtype=x_cont.dtype) # Use normalized space for stable quantiles on x0. @@ -304,11 +315,20 @@ def main(): else: # eps prediction x_gen = (x_cont_t - torch.sqrt(1.0 - a_bar_t) * eps_pred) / torch.sqrt(a_bar_t) + q_clip = float(config.get("quantile_loss_clip", 0.0)) + if q_clip > 0: + x_real = torch.clamp(x_real, -q_clip, q_clip) + x_gen = torch.clamp(x_gen, -q_clip, q_clip) x_real = x_real.view(-1, x_real.size(-1)) x_gen = x_gen.view(-1, x_gen.size(-1)) q_real = torch.quantile(x_real, q_tensor, dim=0) q_gen = torch.quantile(x_gen, q_tensor, dim=0) - quantile_loss = torch.mean(torch.abs(q_gen - q_real)) + q_delta = float(config.get("quantile_loss_huber_delta", 0.0)) + q_diff = q_gen - q_real + if q_delta > 0: + quantile_loss = torch.nn.functional.smooth_l1_loss(q_gen, q_real, beta=q_delta) + else: + quantile_loss = torch.mean(torch.abs(q_diff)) loss = loss + q_weight * quantile_loss opt.zero_grad() loss.backward() @@ -321,10 +341,23 @@ def main(): if step % int(config["log_every"]) == 0: print("epoch", epoch, "step", step, "loss", float(loss)) with open(log_path, "a", encoding="utf-8") as f: - f.write( - "%d,%d,%.6f,%.6f,%.6f\n" - % (epoch, step, float(loss), float(loss_cont), float(loss_disc)) - ) + if use_quantile: + f.write( + "%d,%d,%.6f,%.6f,%.6f,%.6f\n" + % ( + epoch, + step, + float(loss), + float(loss_cont), + float(loss_disc), + float(quantile_loss), + ) + ) + else: + f.write( + "%d,%d,%.6f,%.6f,%.6f\n" + % (epoch, step, float(loss), float(loss_cont), float(loss_disc)) + ) total_step += 1 if total_step % int(config["ckpt_every"]) == 0: