From cc10125fbf9b6ba5dde91498bdd7ae50502c8e82 Mon Sep 17 00:00:00 2001 From: MingzheYang Date: Sun, 25 Jan 2026 17:55:28 +0800 Subject: [PATCH] update ks --- example/config.json | 4 ++-- example/export_samples.py | 4 ++++ example/sample.py | 11 +++++++---- example/train.py | 6 +++++- 4 files changed, 18 insertions(+), 7 deletions(-) diff --git a/example/config.json b/example/config.json index 3f0d3e4..f812197 100644 --- a/example/config.json +++ b/example/config.json @@ -35,9 +35,9 @@ "disc_mask_scale": 0.9, "cont_loss_weighting": "inv_std", "cont_loss_eps": 1e-6, - "cont_target": "x0", + "cont_target": "v", "cont_clamp_x0": 5.0, - "quantile_loss_weight": 0.3, + "quantile_loss_weight": 0.1, "quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95], "shuffle_buffer": 256, "sample_batch_size": 8, diff --git a/example/export_samples.py b/example/export_samples.py index 79f0752..65af6b6 100644 --- a/example/export_samples.py +++ b/example/export_samples.py @@ -201,6 +201,10 @@ def main(): if cont_clamp_x0 > 0: x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0) eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t) + elif cont_target == "v": + v_pred = eps_pred + x0_pred = torch.sqrt(a_bar_t) * x_cont - torch.sqrt(1.0 - a_bar_t) * v_pred + eps_pred = torch.sqrt(1.0 - a_bar_t) * x_cont + torch.sqrt(a_bar_t) * v_pred coef1 = 1.0 / torch.sqrt(a_t) coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t) mean_x = coef1 * (x_cont - coef2 * eps_pred) diff --git a/example/sample.py b/example/sample.py index 9f4a6e1..2c4caf6 100755 --- a/example/sample.py +++ b/example/sample.py @@ -114,15 +114,18 @@ def main(): t_batch = torch.full((batch_size,), t, device=DEVICE, dtype=torch.long) eps_pred, logits = model(x_cont, x_disc, t_batch, cond) + # Continuous reverse step (DDPM): x_{t-1} mean + a_t = alphas[t] + a_bar_t = alphas_cumprod[t] if cont_target == "x0": x0_pred = eps_pred if cont_clamp_x0 > 0: x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0) eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t) - - # Continuous reverse step (DDPM): x_{t-1} mean - a_t = alphas[t] - a_bar_t = alphas_cumprod[t] + elif cont_target == "v": + v_pred = eps_pred + x0_pred = torch.sqrt(a_bar_t) * x_cont - torch.sqrt(1.0 - a_bar_t) * v_pred + eps_pred = torch.sqrt(1.0 - a_bar_t) * x_cont + torch.sqrt(a_bar_t) * v_pred coef1 = 1.0 / torch.sqrt(a_t) coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t) mean = coef1 * (x_cont - coef2 * eps_pred) diff --git a/example/train.py b/example/train.py index f234c9d..f565a9b 100755 --- a/example/train.py +++ b/example/train.py @@ -62,7 +62,7 @@ DEFAULTS = { "shuffle_buffer": 256, "cont_loss_weighting": "none", # none | inv_std "cont_loss_eps": 1e-6, - "cont_target": "eps", # eps | x0 + "cont_target": "eps", # eps | x0 | v "cont_clamp_x0": 0.0, "quantile_loss_weight": 0.0, "quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95], @@ -259,6 +259,10 @@ def main(): if float(config.get("cont_clamp_x0", 0.0)) > 0: x0_target = torch.clamp(x0_target, -float(config["cont_clamp_x0"]), float(config["cont_clamp_x0"])) loss_base = (eps_pred - x0_target) ** 2 + elif cont_target == "v": + a_bar_t = alphas_cumprod[t].view(-1, 1, 1) + v_target = torch.sqrt(a_bar_t) * noise - torch.sqrt(1.0 - a_bar_t) * x_cont + loss_base = (eps_pred - v_target) ** 2 else: loss_base = (eps_pred - noise) ** 2