update新结构
This commit is contained in:
@@ -14,6 +14,7 @@ import torch.nn.functional as F
|
||||
from data_utils import load_split, windowed_batches
|
||||
from hybrid_diffusion import (
|
||||
HybridDiffusionModel,
|
||||
TemporalGRUGenerator,
|
||||
cosine_beta_schedule,
|
||||
q_sample_continuous,
|
||||
q_sample_discrete,
|
||||
@@ -61,6 +62,11 @@ DEFAULTS = {
|
||||
"model_use_feature_graph": True,
|
||||
"feature_graph_scale": 0.1,
|
||||
"feature_graph_dropout": 0.0,
|
||||
"use_temporal_stage1": True,
|
||||
"temporal_hidden_dim": 256,
|
||||
"temporal_num_layers": 1,
|
||||
"temporal_dropout": 0.0,
|
||||
"temporal_loss_weight": 1.0,
|
||||
"disc_mask_scale": 0.9,
|
||||
"shuffle_buffer": 256,
|
||||
"cont_loss_weighting": "none", # none | inv_std
|
||||
@@ -204,7 +210,19 @@ def main():
|
||||
use_tanh_eps=bool(config.get("use_tanh_eps", False)),
|
||||
eps_scale=float(config.get("eps_scale", 1.0)),
|
||||
).to(device)
|
||||
temporal_model = None
|
||||
if bool(config.get("use_temporal_stage1", False)):
|
||||
temporal_model = TemporalGRUGenerator(
|
||||
input_dim=len(cont_cols),
|
||||
hidden_dim=int(config.get("temporal_hidden_dim", 256)),
|
||||
num_layers=int(config.get("temporal_num_layers", 1)),
|
||||
dropout=float(config.get("temporal_dropout", 0.0)),
|
||||
).to(device)
|
||||
opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))
|
||||
if temporal_model is not None:
|
||||
opt_temporal = torch.optim.Adam(temporal_model.parameters(), lr=float(config["lr"]))
|
||||
else:
|
||||
opt_temporal = None
|
||||
ema = EMA(model, float(config["ema_decay"])) if config.get("use_ema") else None
|
||||
|
||||
betas = cosine_beta_schedule(int(config["timesteps"])).to(device)
|
||||
@@ -250,10 +268,20 @@ def main():
|
||||
x_cont = x_cont.to(device)
|
||||
x_disc = x_disc.to(device)
|
||||
|
||||
temporal_loss = None
|
||||
x_cont_resid = x_cont
|
||||
trend = None
|
||||
if temporal_model is not None:
|
||||
trend, pred_next = temporal_model.forward_teacher(x_cont)
|
||||
temporal_loss = F.mse_loss(pred_next, x_cont[:, 1:, :])
|
||||
temporal_loss = temporal_loss * float(config.get("temporal_loss_weight", 1.0))
|
||||
trend = trend.detach()
|
||||
x_cont_resid = x_cont - trend
|
||||
|
||||
bsz = x_cont.size(0)
|
||||
t = torch.randint(0, int(config["timesteps"]), (bsz,), device=device)
|
||||
|
||||
x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
|
||||
x_cont_t, noise = q_sample_continuous(x_cont_resid, t, alphas_cumprod)
|
||||
|
||||
mask_tokens = torch.tensor(vocab_sizes, device=device)
|
||||
x_disc_t, mask = q_sample_discrete(
|
||||
@@ -268,13 +296,13 @@ def main():
|
||||
|
||||
cont_target = str(config.get("cont_target", "eps"))
|
||||
if cont_target == "x0":
|
||||
x0_target = x_cont
|
||||
x0_target = x_cont_resid
|
||||
if float(config.get("cont_clamp_x0", 0.0)) > 0:
|
||||
x0_target = torch.clamp(x0_target, -float(config["cont_clamp_x0"]), float(config["cont_clamp_x0"]))
|
||||
loss_base = (eps_pred - x0_target) ** 2
|
||||
elif cont_target == "v":
|
||||
a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
|
||||
v_target = torch.sqrt(a_bar_t) * noise - torch.sqrt(1.0 - a_bar_t) * x_cont
|
||||
v_target = torch.sqrt(a_bar_t) * noise - torch.sqrt(1.0 - a_bar_t) * x_cont_resid
|
||||
loss_base = (eps_pred - v_target) ** 2
|
||||
else:
|
||||
loss_base = (eps_pred - noise) ** 2
|
||||
@@ -311,7 +339,7 @@ def main():
|
||||
q_points = config.get("quantile_points", [0.05, 0.25, 0.5, 0.75, 0.95])
|
||||
q_tensor = torch.tensor(q_points, device=device, dtype=x_cont.dtype)
|
||||
# Use normalized space for stable quantiles on x0.
|
||||
x_real = x_cont
|
||||
x_real = x_cont_resid
|
||||
a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
|
||||
if cont_target == "x0":
|
||||
x_gen = eps_pred
|
||||
@@ -336,11 +364,18 @@ def main():
|
||||
else:
|
||||
quantile_loss = torch.mean(torch.abs(q_diff))
|
||||
loss = loss + q_weight * quantile_loss
|
||||
|
||||
opt.zero_grad()
|
||||
loss.backward()
|
||||
if float(config.get("grad_clip", 0.0)) > 0:
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), float(config["grad_clip"]))
|
||||
opt.step()
|
||||
if opt_temporal is not None:
|
||||
opt_temporal.zero_grad()
|
||||
temporal_loss.backward()
|
||||
if float(config.get("grad_clip", 0.0)) > 0:
|
||||
torch.nn.utils.clip_grad_norm_(temporal_model.parameters(), float(config["grad_clip"]))
|
||||
opt_temporal.step()
|
||||
if ema is not None:
|
||||
ema.update(model)
|
||||
|
||||
@@ -375,11 +410,15 @@ def main():
|
||||
}
|
||||
if ema is not None:
|
||||
ckpt["ema"] = ema.state_dict()
|
||||
if temporal_model is not None:
|
||||
ckpt["temporal"] = temporal_model.state_dict()
|
||||
torch.save(ckpt, os.path.join(out_dir, "model_ckpt.pt"))
|
||||
|
||||
torch.save(model.state_dict(), os.path.join(out_dir, "model.pt"))
|
||||
if ema is not None:
|
||||
torch.save(ema.state_dict(), os.path.join(out_dir, "model_ema.pt"))
|
||||
if temporal_model is not None:
|
||||
torch.save(temporal_model.state_dict(), os.path.join(out_dir, "temporal.pt"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user