From 2a1a9a05c6cf4b07ff1dc5ee66d85a1bc39e07ae Mon Sep 17 00:00:00 2001 From: MingzheYang Date: Sun, 25 Jan 2026 18:00:28 +0800 Subject: [PATCH] update ks --- example/train.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/example/train.py b/example/train.py index f565a9b..4cee971 100755 --- a/example/train.py +++ b/example/train.py @@ -293,12 +293,17 @@ def main(): if q_weight > 0: q_points = config.get("quantile_points", [0.05, 0.25, 0.5, 0.75, 0.95]) q_tensor = torch.tensor(q_points, device=device, dtype=x_cont.dtype) - # Use normalized space for stable quantiles. + # Use normalized space for stable quantiles on x0. x_real = x_cont + a_bar_t = alphas_cumprod[t].view(-1, 1, 1) if cont_target == "x0": x_gen = eps_pred + elif cont_target == "v": + v_pred = eps_pred + x_gen = torch.sqrt(a_bar_t) * x_cont_t - torch.sqrt(1.0 - a_bar_t) * v_pred else: - x_gen = x_cont - noise + # eps prediction + x_gen = (x_cont_t - torch.sqrt(1.0 - a_bar_t) * eps_pred) / torch.sqrt(a_bar_t) x_real = x_real.view(-1, x_real.size(-1)) x_gen = x_gen.view(-1, x_gen.size(-1)) q_real = torch.quantile(x_real, q_tensor, dim=0)