update ks
This commit is contained in:
@@ -293,12 +293,17 @@ def main():
|
|||||||
if q_weight > 0:
|
if q_weight > 0:
|
||||||
q_points = config.get("quantile_points", [0.05, 0.25, 0.5, 0.75, 0.95])
|
q_points = config.get("quantile_points", [0.05, 0.25, 0.5, 0.75, 0.95])
|
||||||
q_tensor = torch.tensor(q_points, device=device, dtype=x_cont.dtype)
|
q_tensor = torch.tensor(q_points, device=device, dtype=x_cont.dtype)
|
||||||
# Use normalized space for stable quantiles.
|
# Use normalized space for stable quantiles on x0.
|
||||||
x_real = x_cont
|
x_real = x_cont
|
||||||
|
a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
|
||||||
if cont_target == "x0":
|
if cont_target == "x0":
|
||||||
x_gen = eps_pred
|
x_gen = eps_pred
|
||||||
|
elif cont_target == "v":
|
||||||
|
v_pred = eps_pred
|
||||||
|
x_gen = torch.sqrt(a_bar_t) * x_cont_t - torch.sqrt(1.0 - a_bar_t) * v_pred
|
||||||
else:
|
else:
|
||||||
x_gen = x_cont - noise
|
# eps prediction
|
||||||
|
x_gen = (x_cont_t - torch.sqrt(1.0 - a_bar_t) * eps_pred) / torch.sqrt(a_bar_t)
|
||||||
x_real = x_real.view(-1, x_real.size(-1))
|
x_real = x_real.view(-1, x_real.size(-1))
|
||||||
x_gen = x_gen.view(-1, x_gen.size(-1))
|
x_gen = x_gen.view(-1, x_gen.size(-1))
|
||||||
q_real = torch.quantile(x_real, q_tensor, dim=0)
|
q_real = torch.quantile(x_real, q_tensor, dim=0)
|
||||||
|
|||||||
Reference in New Issue
Block a user