update ks
This commit is contained in:
@@ -37,6 +37,8 @@
|
|||||||
"cont_loss_eps": 1e-6,
|
"cont_loss_eps": 1e-6,
|
||||||
"cont_target": "x0",
|
"cont_target": "x0",
|
||||||
"cont_clamp_x0": 5.0,
|
"cont_clamp_x0": 5.0,
|
||||||
|
"quantile_loss_weight": 0.1,
|
||||||
|
"quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
|
||||||
"shuffle_buffer": 256,
|
"shuffle_buffer": 256,
|
||||||
"sample_batch_size": 8,
|
"sample_batch_size": 8,
|
||||||
"sample_seq_len": 128
|
"sample_seq_len": 128
|
||||||
|
|||||||
@@ -64,6 +64,8 @@ DEFAULTS = {
|
|||||||
"cont_loss_eps": 1e-6,
|
"cont_loss_eps": 1e-6,
|
||||||
"cont_target": "eps", # eps | x0
|
"cont_target": "eps", # eps | x0
|
||||||
"cont_clamp_x0": 0.0,
|
"cont_clamp_x0": 0.0,
|
||||||
|
"quantile_loss_weight": 0.0,
|
||||||
|
"quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -282,6 +284,23 @@ def main():
|
|||||||
|
|
||||||
lam = float(config["lambda"])
|
lam = float(config["lambda"])
|
||||||
loss = lam * loss_cont + (1 - lam) * loss_disc
|
loss = lam * loss_cont + (1 - lam) * loss_disc
|
||||||
|
|
||||||
|
q_weight = float(config.get("quantile_loss_weight", 0.0))
|
||||||
|
if q_weight > 0:
|
||||||
|
q_points = config.get("quantile_points", [0.05, 0.25, 0.5, 0.75, 0.95])
|
||||||
|
q_tensor = torch.tensor(q_points, device=device, dtype=x_cont.dtype)
|
||||||
|
# Use normalized space for stable quantiles.
|
||||||
|
x_real = x_cont
|
||||||
|
if cont_target == "x0":
|
||||||
|
x_gen = eps_pred
|
||||||
|
else:
|
||||||
|
x_gen = x_cont - noise
|
||||||
|
x_real = x_real.view(-1, x_real.size(-1))
|
||||||
|
x_gen = x_gen.view(-1, x_gen.size(-1))
|
||||||
|
q_real = torch.quantile(x_real, q_tensor, dim=0)
|
||||||
|
q_gen = torch.quantile(x_gen, q_tensor, dim=0)
|
||||||
|
quantile_loss = torch.mean(torch.abs(q_gen - q_real))
|
||||||
|
loss = loss + q_weight * quantile_loss
|
||||||
opt.zero_grad()
|
opt.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
if float(config.get("grad_clip", 0.0)) > 0:
|
if float(config.get("grad_clip", 0.0)) > 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user