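update ks: add "v" (v-prediction) option for cont_target in training and sampling; switch config to cont_target "v" and lower quantile_loss_weight from 0.3 to 0.1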
This commit is contained in:
@@ -35,9 +35,9 @@
|
|||||||
"disc_mask_scale": 0.9,
|
"disc_mask_scale": 0.9,
|
||||||
"cont_loss_weighting": "inv_std",
|
"cont_loss_weighting": "inv_std",
|
||||||
"cont_loss_eps": 1e-6,
|
"cont_loss_eps": 1e-6,
|
||||||
"cont_target": "x0",
|
"cont_target": "v",
|
||||||
"cont_clamp_x0": 5.0,
|
"cont_clamp_x0": 5.0,
|
||||||
"quantile_loss_weight": 0.3,
|
"quantile_loss_weight": 0.1,
|
||||||
"quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
|
"quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
|
||||||
"shuffle_buffer": 256,
|
"shuffle_buffer": 256,
|
||||||
"sample_batch_size": 8,
|
"sample_batch_size": 8,
|
||||||
|
|||||||
@@ -201,6 +201,10 @@ def main():
|
|||||||
if cont_clamp_x0 > 0:
|
if cont_clamp_x0 > 0:
|
||||||
x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0)
|
x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0)
|
||||||
eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)
|
eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)
|
||||||
|
elif cont_target == "v":
|
||||||
|
v_pred = eps_pred
|
||||||
|
x0_pred = torch.sqrt(a_bar_t) * x_cont - torch.sqrt(1.0 - a_bar_t) * v_pred
|
||||||
|
eps_pred = torch.sqrt(1.0 - a_bar_t) * x_cont + torch.sqrt(a_bar_t) * v_pred
|
||||||
coef1 = 1.0 / torch.sqrt(a_t)
|
coef1 = 1.0 / torch.sqrt(a_t)
|
||||||
coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
|
coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
|
||||||
mean_x = coef1 * (x_cont - coef2 * eps_pred)
|
mean_x = coef1 * (x_cont - coef2 * eps_pred)
|
||||||
|
|||||||
@@ -114,15 +114,18 @@ def main():
|
|||||||
t_batch = torch.full((batch_size,), t, device=DEVICE, dtype=torch.long)
|
t_batch = torch.full((batch_size,), t, device=DEVICE, dtype=torch.long)
|
||||||
eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
|
eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
|
||||||
|
|
||||||
|
# Continuous reverse step (DDPM): x_{t-1} mean
|
||||||
|
a_t = alphas[t]
|
||||||
|
a_bar_t = alphas_cumprod[t]
|
||||||
if cont_target == "x0":
|
if cont_target == "x0":
|
||||||
x0_pred = eps_pred
|
x0_pred = eps_pred
|
||||||
if cont_clamp_x0 > 0:
|
if cont_clamp_x0 > 0:
|
||||||
x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0)
|
x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0)
|
||||||
eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)
|
eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)
|
||||||
|
elif cont_target == "v":
|
||||||
# Continuous reverse step (DDPM): x_{t-1} mean
|
v_pred = eps_pred
|
||||||
a_t = alphas[t]
|
x0_pred = torch.sqrt(a_bar_t) * x_cont - torch.sqrt(1.0 - a_bar_t) * v_pred
|
||||||
a_bar_t = alphas_cumprod[t]
|
eps_pred = torch.sqrt(1.0 - a_bar_t) * x_cont + torch.sqrt(a_bar_t) * v_pred
|
||||||
coef1 = 1.0 / torch.sqrt(a_t)
|
coef1 = 1.0 / torch.sqrt(a_t)
|
||||||
coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
|
coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
|
||||||
mean = coef1 * (x_cont - coef2 * eps_pred)
|
mean = coef1 * (x_cont - coef2 * eps_pred)
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ DEFAULTS = {
|
|||||||
"shuffle_buffer": 256,
|
"shuffle_buffer": 256,
|
||||||
"cont_loss_weighting": "none", # none | inv_std
|
"cont_loss_weighting": "none", # none | inv_std
|
||||||
"cont_loss_eps": 1e-6,
|
"cont_loss_eps": 1e-6,
|
||||||
"cont_target": "eps", # eps | x0
|
"cont_target": "eps", # eps | x0 | v
|
||||||
"cont_clamp_x0": 0.0,
|
"cont_clamp_x0": 0.0,
|
||||||
"quantile_loss_weight": 0.0,
|
"quantile_loss_weight": 0.0,
|
||||||
"quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
|
"quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
|
||||||
@@ -259,6 +259,10 @@ def main():
|
|||||||
if float(config.get("cont_clamp_x0", 0.0)) > 0:
|
if float(config.get("cont_clamp_x0", 0.0)) > 0:
|
||||||
x0_target = torch.clamp(x0_target, -float(config["cont_clamp_x0"]), float(config["cont_clamp_x0"]))
|
x0_target = torch.clamp(x0_target, -float(config["cont_clamp_x0"]), float(config["cont_clamp_x0"]))
|
||||||
loss_base = (eps_pred - x0_target) ** 2
|
loss_base = (eps_pred - x0_target) ** 2
|
||||||
|
elif cont_target == "v":
|
||||||
|
a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
|
||||||
|
v_target = torch.sqrt(a_bar_t) * noise - torch.sqrt(1.0 - a_bar_t) * x_cont
|
||||||
|
loss_base = (eps_pred - v_target) ** 2
|
||||||
else:
|
else:
|
||||||
loss_base = (eps_pred - noise) ** 2
|
loss_base = (eps_pred - noise) ** 2
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user