update
This commit is contained in:
@@ -46,6 +46,9 @@
|
|||||||
"temporal_lr": 0.001,
|
"temporal_lr": 0.001,
|
||||||
"quantile_loss_weight": 0.2,
|
"quantile_loss_weight": 0.2,
|
||||||
"quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
|
"quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
|
||||||
|
"snr_weighted_loss": true,
|
||||||
|
"snr_gamma": 1.0,
|
||||||
|
"residual_stat_weight": 0.05,
|
||||||
"sample_batch_size": 8,
|
"sample_batch_size": 8,
|
||||||
"sample_seq_len": 128
|
"sample_seq_len": 128
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -73,6 +73,9 @@ DEFAULTS = {
|
|||||||
"temporal_lr": 1e-3,
|
"temporal_lr": 1e-3,
|
||||||
"quantile_loss_weight": 0.0,
|
"quantile_loss_weight": 0.0,
|
||||||
"quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
|
"quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
|
||||||
|
"snr_weighted_loss": True,
|
||||||
|
"snr_gamma": 1.0,
|
||||||
|
"residual_stat_weight": 0.0,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -329,6 +332,13 @@ def main():
|
|||||||
loss_cont = (loss_base * weights).mean()
|
loss_cont = (loss_base * weights).mean()
|
||||||
else:
|
else:
|
||||||
loss_cont = loss_base.mean()
|
loss_cont = loss_base.mean()
|
||||||
|
|
||||||
|
if bool(config.get("snr_weighted_loss", False)):
|
||||||
|
a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
|
||||||
|
snr = a_bar_t / torch.clamp(1.0 - a_bar_t, min=1e-8)
|
||||||
|
gamma = float(config.get("snr_gamma", 1.0))
|
||||||
|
snr_weight = snr / (snr + gamma)
|
||||||
|
loss_cont = (loss_cont * snr_weight.mean()).mean()
|
||||||
loss_disc = 0.0
|
loss_disc = 0.0
|
||||||
loss_disc_count = 0
|
loss_disc_count = 0
|
||||||
for i, logit in enumerate(logits):
|
for i, logit in enumerate(logits):
|
||||||
@@ -360,6 +370,22 @@ def main():
|
|||||||
q_gen = torch.quantile(x_gen, q_tensor, dim=0)
|
q_gen = torch.quantile(x_gen, q_tensor, dim=0)
|
||||||
quantile_loss = torch.mean(torch.abs(q_gen - q_real))
|
quantile_loss = torch.mean(torch.abs(q_gen - q_real))
|
||||||
loss = loss + q_weight * quantile_loss
|
loss = loss + q_weight * quantile_loss
|
||||||
|
|
||||||
|
stat_weight = float(config.get("residual_stat_weight", 0.0))
|
||||||
|
if stat_weight > 0:
|
||||||
|
# residual distribution matching (mean/std)
|
||||||
|
a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
|
||||||
|
if cont_target == "x0":
|
||||||
|
x_gen = eps_pred
|
||||||
|
else:
|
||||||
|
x_gen = (x_cont_t - torch.sqrt(1.0 - a_bar_t) * eps_pred) / torch.sqrt(a_bar_t)
|
||||||
|
x_real = x_cont_resid
|
||||||
|
mean_real = x_real.mean(dim=(0, 1))
|
||||||
|
mean_gen = x_gen.mean(dim=(0, 1))
|
||||||
|
std_real = x_real.std(dim=(0, 1))
|
||||||
|
std_gen = x_gen.std(dim=(0, 1))
|
||||||
|
stat_loss = F.mse_loss(mean_gen, mean_real) + F.mse_loss(std_gen, std_real)
|
||||||
|
loss = loss + stat_weight * stat_loss
|
||||||
opt.zero_grad()
|
opt.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
if float(config.get("grad_clip", 0.0)) > 0:
|
if float(config.get("grad_clip", 0.0)) > 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user