diff --git a/example/config.json b/example/config.json
index 779861a..d097c38 100644
--- a/example/config.json
+++ b/example/config.json
@@ -46,6 +46,9 @@
   "temporal_lr": 0.001,
   "quantile_loss_weight": 0.2,
   "quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
+  "snr_weighted_loss": true,
+  "snr_gamma": 1.0,
+  "residual_stat_weight": 0.05,
   "sample_batch_size": 8,
   "sample_seq_len": 128
 }
diff --git a/example/train.py b/example/train.py
index 9e0fbad..9f178c0 100755
--- a/example/train.py
+++ b/example/train.py
@@ -73,6 +73,9 @@ DEFAULTS = {
     "temporal_lr": 1e-3,
     "quantile_loss_weight": 0.0,
     "quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
+    "snr_weighted_loss": True,
+    "snr_gamma": 1.0,
+    "residual_stat_weight": 0.0,
 }
 
 
@@ -329,6 +332,13 @@ def main():
             loss_cont = (loss_base * weights).mean()
         else:
             loss_cont = loss_base.mean()
+
+        if bool(config.get("snr_weighted_loss", False)):
+            a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
+            snr = a_bar_t / torch.clamp(1.0 - a_bar_t, min=1e-8)
+            gamma = float(config.get("snr_gamma", 1.0))
+            snr_weight = snr / (snr + gamma)
+            loss_cont = (loss_cont * snr_weight.mean()).mean()
         loss_disc = 0.0
         loss_disc_count = 0
         for i, logit in enumerate(logits):
@@ -360,6 +370,22 @@ def main():
             q_gen = torch.quantile(x_gen, q_tensor, dim=0)
             quantile_loss = torch.mean(torch.abs(q_gen - q_real))
             loss = loss + q_weight * quantile_loss
+
+        stat_weight = float(config.get("residual_stat_weight", 0.0))
+        if stat_weight > 0:
+            # residual distribution matching (mean/std)
+            a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
+            if cont_target == "x0":
+                x_gen = eps_pred
+            else:
+                x_gen = (x_cont_t - torch.sqrt(1.0 - a_bar_t) * eps_pred) / torch.sqrt(a_bar_t)
+            x_real = x_cont_resid
+            mean_real = x_real.mean(dim=(0, 1))
+            mean_gen = x_gen.mean(dim=(0, 1))
+            std_real = x_real.std(dim=(0, 1))
+            std_gen = x_gen.std(dim=(0, 1))
+            stat_loss = F.mse_loss(mean_gen, mean_real) + F.mse_loss(std_gen, std_real)
+            loss = loss + stat_weight * stat_loss
         opt.zero_grad()
         loss.backward()
         if float(config.get("grad_clip", 0.0)) > 0:
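
For reference, the SNR weighting added above can be exercised in isolation. The sketch below is illustrative only and is not part of the patch: it assumes a standard linear beta schedule and per-sample MSE losses, and every name in it is made up for the example rather than taken from train.py.

# Minimal standalone sketch of min-SNR-gamma weighting (illustrative names only).
import torch

def min_snr_weight(alphas_cumprod, t, gamma=1.0):
    # SNR(t) = a_bar_t / (1 - a_bar_t); the weight approaches 1 when SNR >> gamma
    a_bar_t = alphas_cumprod[t]
    snr = a_bar_t / torch.clamp(1.0 - a_bar_t, min=1e-8)
    return snr / (snr + gamma)

betas = torch.linspace(1e-4, 0.02, 1000)            # assumed linear beta schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
t = torch.randint(0, 1000, (8,))                     # one timestep per batch element
per_sample_loss = torch.rand(8)                      # stand-in for the per-sample MSE
loss_cont = (per_sample_loss * min_snr_weight(alphas_cumprod, t)).mean()

Note that the sketch weights each sample's loss before reduction, which is the usual min-SNR-gamma formulation, whereas the hunk above multiplies the already-reduced loss_cont by the batch mean of the weight; the sketch is only meant to make the weighting itself explicit.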