From a870945e338de8ab152ab38380d42cd956e80dfe Mon Sep 17 00:00:00 2001 From: MingzheYang Date: Sun, 25 Jan 2026 17:16:57 +0800 Subject: [PATCH] update ks --- example/config.json | 2 ++ example/train.py | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/example/config.json b/example/config.json index eefcd91..4172b8d 100644 --- a/example/config.json +++ b/example/config.json @@ -37,6 +37,8 @@ "cont_loss_eps": 1e-6, "cont_target": "x0", "cont_clamp_x0": 5.0, + "quantile_loss_weight": 0.1, + "quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95], "shuffle_buffer": 256, "sample_batch_size": 8, "sample_seq_len": 128 diff --git a/example/train.py b/example/train.py index a348f57..f234c9d 100755 --- a/example/train.py +++ b/example/train.py @@ -64,6 +64,8 @@ DEFAULTS = { "cont_loss_eps": 1e-6, "cont_target": "eps", # eps | x0 "cont_clamp_x0": 0.0, + "quantile_loss_weight": 0.0, + "quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95], } @@ -282,6 +284,23 @@ def main(): lam = float(config["lambda"]) loss = lam * loss_cont + (1 - lam) * loss_disc + + q_weight = float(config.get("quantile_loss_weight", 0.0)) + if q_weight > 0: + q_points = config.get("quantile_points", [0.05, 0.25, 0.5, 0.75, 0.95]) + q_tensor = torch.tensor(q_points, device=device, dtype=x_cont.dtype) + # Use normalized space for stable quantiles. + x_real = x_cont + if cont_target == "x0": + x_gen = eps_pred + else: + x_gen = x_cont - noise + x_real = x_real.view(-1, x_real.size(-1)) + x_gen = x_gen.view(-1, x_gen.size(-1)) + q_real = torch.quantile(x_real, q_tensor, dim=0) + q_gen = torch.quantile(x_gen, q_tensor, dim=0) + quantile_loss = torch.mean(torch.abs(q_gen - q_real)) + loss = loss + q_weight * quantile_loss opt.zero_grad() loss.backward() if float(config.get("grad_clip", 0.0)) > 0: