update: new structure

2026-01-26 19:00:16 +08:00
parent f8edee9510
commit dd4c1e171f
9 changed files with 545 additions and 5 deletions


@@ -14,6 +14,7 @@ import torch.nn.functional as F
 from data_utils import load_split, windowed_batches
 from hybrid_diffusion import (
     HybridDiffusionModel,
+    TemporalGRUGenerator,
     cosine_beta_schedule,
     q_sample_continuous,
     q_sample_discrete,
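The newly imported TemporalGRUGenerator lives in hybrid_diffusion.py, which is not shown in this commit view. A minimal sketch of what it could look like, inferred only from its call sites below (forward_teacher takes the continuous window and returns a trend with the input's shape plus one-step-ahead predictions aligned with x[:, 1:, :]); the real class may differ:

# Hypothetical sketch of TemporalGRUGenerator, inferred from its usage in this
# diff; not the repo's actual implementation.
import torch
import torch.nn as nn

class TemporalGRUGenerator(nn.Module):
    def __init__(self, input_dim, hidden_dim=256, num_layers=1, dropout=0.0):
        super().__init__()
        self.gru = nn.GRU(
            input_dim, hidden_dim, num_layers=num_layers, batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.head = nn.Linear(hidden_dim, input_dim)

    def forward_teacher(self, x):
        # x: (batch, seq_len, input_dim), teacher-forced over the full window.
        h, _ = self.gru(x)
        pred = self.head(h)                                  # step t predicts step t+1
        pred_next = pred[:, :-1, :]                          # aligned with x[:, 1:, :]
        trend = torch.cat([x[:, :1, :], pred_next], dim=1)   # same shape as x
        return trend, pred_next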
@@ -61,6 +62,11 @@ DEFAULTS = {
     "model_use_feature_graph": True,
     "feature_graph_scale": 0.1,
     "feature_graph_dropout": 0.0,
+    "use_temporal_stage1": True,
+    "temporal_hidden_dim": 256,
+    "temporal_num_layers": 1,
+    "temporal_dropout": 0.0,
+    "temporal_loss_weight": 1.0,
     "disc_mask_scale": 0.9,
     "shuffle_buffer": 256,
     "cont_loss_weighting": "none",  # none | inv_std
@@ -204,7 +210,19 @@ def main():
         use_tanh_eps=bool(config.get("use_tanh_eps", False)),
         eps_scale=float(config.get("eps_scale", 1.0)),
     ).to(device)
+    temporal_model = None
+    if bool(config.get("use_temporal_stage1", False)):
+        temporal_model = TemporalGRUGenerator(
+            input_dim=len(cont_cols),
+            hidden_dim=int(config.get("temporal_hidden_dim", 256)),
+            num_layers=int(config.get("temporal_num_layers", 1)),
+            dropout=float(config.get("temporal_dropout", 0.0)),
+        ).to(device)
     opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))
+    if temporal_model is not None:
+        opt_temporal = torch.optim.Adam(temporal_model.parameters(), lr=float(config["lr"]))
+    else:
+        opt_temporal = None
     ema = EMA(model, float(config["ema_decay"])) if config.get("use_ema") else None
     betas = cosine_beta_schedule(int(config["timesteps"])).to(device)
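cosine_beta_schedule is imported rather than defined here. A standard cosine schedule in the style of Nichol & Dhariwal (2021) would be a plausible stand-in for it:

# Plausible stand-in for the imported cosine_beta_schedule (an assumption):
# squared-cosine alphas_cumprod, converted to per-step betas and clipped.
import torch

def cosine_beta_schedule(timesteps, s=0.008):
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1.0 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clamp(betas, 1e-4, 0.999)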
@@ -250,10 +268,20 @@ def main():
         x_cont = x_cont.to(device)
         x_disc = x_disc.to(device)
+        temporal_loss = None
+        x_cont_resid = x_cont
+        trend = None
+        if temporal_model is not None:
+            trend, pred_next = temporal_model.forward_teacher(x_cont)
+            temporal_loss = F.mse_loss(pred_next, x_cont[:, 1:, :])
+            temporal_loss = temporal_loss * float(config.get("temporal_loss_weight", 1.0))
+            trend = trend.detach()
+            x_cont_resid = x_cont - trend
         bsz = x_cont.size(0)
         t = torch.randint(0, int(config["timesteps"]), (bsz,), device=device)
-        x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
+        x_cont_t, noise = q_sample_continuous(x_cont_resid, t, alphas_cumprod)
         mask_tokens = torch.tensor(vocab_sizes, device=device)
         x_disc_t, mask = q_sample_discrete(
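The hunk is truncated mid-call to q_sample_discrete. On the continuous side, the q_sample_continuous call above (now noising the GRU residual rather than the raw series) matches the standard DDPM forward process; a plausible stand-in for the imported function, not necessarily the repo's implementation:

# Standard DDPM forward noising: x_t = sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * eps.
def q_sample_continuous(x0, t, alphas_cumprod):
    a_bar = alphas_cumprod[t].view(-1, 1, 1)   # broadcast over (seq, feature)
    noise = torch.randn_like(x0)
    x_t = torch.sqrt(a_bar) * x0 + torch.sqrt(1.0 - a_bar) * noise
    return x_t, noise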
@@ -268,13 +296,13 @@ def main():
         cont_target = str(config.get("cont_target", "eps"))
         if cont_target == "x0":
-            x0_target = x_cont
+            x0_target = x_cont_resid
             if float(config.get("cont_clamp_x0", 0.0)) > 0:
                 x0_target = torch.clamp(x0_target, -float(config["cont_clamp_x0"]), float(config["cont_clamp_x0"]))
             loss_base = (eps_pred - x0_target) ** 2
         elif cont_target == "v":
             a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
-            v_target = torch.sqrt(a_bar_t) * noise - torch.sqrt(1.0 - a_bar_t) * x_cont
+            v_target = torch.sqrt(a_bar_t) * noise - torch.sqrt(1.0 - a_bar_t) * x_cont_resid
             loss_base = (eps_pred - v_target) ** 2
         else:
             loss_base = (eps_pred - noise) ** 2
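The v-target above follows the v-parameterization of Salimans & Ho (2022), now built from the residual instead of the raw series. A sanity sketch, reusing the hunk's variables, showing why this is consistent with noising x_cont_resid: x0 is exactly recoverable from x_t and v.

# With x_t = sqrt(a_bar)*x0 + sqrt(1-a_bar)*eps and
#      v   = sqrt(a_bar)*eps - sqrt(1-a_bar)*x0,
# algebra gives x0 = sqrt(a_bar)*x_t - sqrt(1-a_bar)*v.
a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
x0_hat = torch.sqrt(a_bar_t) * x_cont_t - torch.sqrt(1.0 - a_bar_t) * v_target
assert torch.allclose(x0_hat, x_cont_resid, atol=1e-5)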
@@ -311,7 +339,7 @@ def main():
             q_points = config.get("quantile_points", [0.05, 0.25, 0.5, 0.75, 0.95])
             q_tensor = torch.tensor(q_points, device=device, dtype=x_cont.dtype)
             # Use normalized space for stable quantiles on x0.
-            x_real = x_cont
+            x_real = x_cont_resid
             a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
             if cont_target == "x0":
                 x_gen = eps_pred
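The quantile penalty compares generated and real x0 in normalized space; q_diff appears in the next hunk, but its construction sits outside this diff. A hypothetical construction consistent with the surrounding variables:

# Assumed construction of q_diff (not shown in this commit): per-feature
# quantiles of the model's x0 estimate vs. the (residual) real data.
q_gen = torch.quantile(x_gen.reshape(-1, x_gen.size(-1)), q_tensor, dim=0)
q_real = torch.quantile(x_real.reshape(-1, x_real.size(-1)), q_tensor, dim=0)
q_diff = q_gen - q_real   # shape: (num_quantiles, num_features)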
@@ -336,11 +364,18 @@ def main():
             else:
                 quantile_loss = torch.mean(torch.abs(q_diff))
             loss = loss + q_weight * quantile_loss
         opt.zero_grad()
         loss.backward()
         if float(config.get("grad_clip", 0.0)) > 0:
             torch.nn.utils.clip_grad_norm_(model.parameters(), float(config["grad_clip"]))
         opt.step()
+        if opt_temporal is not None:
+            opt_temporal.zero_grad()
+            temporal_loss.backward()
+            if float(config.get("grad_clip", 0.0)) > 0:
+                torch.nn.utils.clip_grad_norm_(temporal_model.parameters(), float(config["grad_clip"]))
+            opt_temporal.step()
         if ema is not None:
             ema.update(model)
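Because trend is detached before entering the diffusion loss, the two backward passes stay independent: the GRU trains only on its one-step-ahead MSE, the diffusion model only on the residual losses. The EMA class itself is not in this diff; a plausible minimal implementation matching how it is used here (constructed from (model, decay), updated after each step, serialized via state_dict) would be:

# Assumed EMA helper, consistent with its call sites in this file but not
# taken from the repo: keeps an exponential moving average of model weights.
class EMA:
    def __init__(self, model, decay):
        self.decay = decay
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}

    def update(self, model):
        for k, v in model.state_dict().items():
            if v.dtype.is_floating_point:
                self.shadow[k].mul_(self.decay).add_(v.detach(), alpha=1.0 - self.decay)
            else:
                self.shadow[k] = v.detach().clone()

    def state_dict(self):
        return self.shadow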
@@ -375,11 +410,15 @@ def main():
     }
     if ema is not None:
         ckpt["ema"] = ema.state_dict()
+    if temporal_model is not None:
+        ckpt["temporal"] = temporal_model.state_dict()
     torch.save(ckpt, os.path.join(out_dir, "model_ckpt.pt"))
     torch.save(model.state_dict(), os.path.join(out_dir, "model.pt"))
     if ema is not None:
         torch.save(ema.state_dict(), os.path.join(out_dir, "model_ema.pt"))
+    if temporal_model is not None:
+        torch.save(temporal_model.state_dict(), os.path.join(out_dir, "temporal.pt"))
 if __name__ == "__main__":
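On the consumer side, restoring both stages from the combined checkpoint could look like the sketch below. The "temporal" key is confirmed by the hunk above; the "model" key name is an assumption, since the rest of the ckpt dict is outside this diff.

# Hedged loading sketch for the two-stage checkpoint written above.
ckpt = torch.load(os.path.join(out_dir, "model_ckpt.pt"), map_location=device)
model.load_state_dict(ckpt["model"])   # "model" key assumed, not shown in diff
if "temporal" in ckpt:
    temporal_model.load_state_dict(ckpt["temporal"])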