back
This commit is contained in:
@@ -14,6 +14,7 @@ import torch.nn.functional as F
|
||||
from data_utils import load_split, windowed_batches
|
||||
from hybrid_diffusion import (
|
||||
HybridDiffusionModel,
|
||||
TemporalGRUGenerator,
|
||||
cosine_beta_schedule,
|
||||
q_sample_continuous,
|
||||
q_sample_discrete,
|
||||
@@ -64,6 +65,12 @@ DEFAULTS = {
|
||||
"cont_loss_eps": 1e-6,
|
||||
"cont_target": "eps", # eps | x0
|
||||
"cont_clamp_x0": 0.0,
|
||||
"use_temporal_stage1": True,
|
||||
"temporal_hidden_dim": 256,
|
||||
"temporal_num_layers": 1,
|
||||
"temporal_dropout": 0.0,
|
||||
"temporal_epochs": 2,
|
||||
"temporal_lr": 1e-3,
|
||||
}
|
||||
|
||||
|
||||
@@ -194,6 +201,19 @@ def main():
|
||||
eps_scale=float(config.get("eps_scale", 1.0)),
|
||||
).to(device)
|
||||
opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))
|
||||
temporal_model = None
|
||||
opt_temporal = None
|
||||
if bool(config.get("use_temporal_stage1", False)):
|
||||
temporal_model = TemporalGRUGenerator(
|
||||
input_dim=len(cont_cols),
|
||||
hidden_dim=int(config.get("temporal_hidden_dim", 256)),
|
||||
num_layers=int(config.get("temporal_num_layers", 1)),
|
||||
dropout=float(config.get("temporal_dropout", 0.0)),
|
||||
).to(device)
|
||||
opt_temporal = torch.optim.Adam(
|
||||
temporal_model.parameters(),
|
||||
lr=float(config.get("temporal_lr", config["lr"])),
|
||||
)
|
||||
ema = EMA(model, float(config["ema_decay"])) if config.get("use_ema") else None
|
||||
|
||||
betas = cosine_beta_schedule(int(config["timesteps"])).to(device)
|
||||
@@ -208,6 +228,37 @@ def main():
|
||||
with open(os.path.join(out_dir, "config_used.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(config, f, indent=2)
|
||||
|
||||
if temporal_model is not None and opt_temporal is not None:
|
||||
for epoch in range(int(config.get("temporal_epochs", 1))):
|
||||
for step, batch in enumerate(
|
||||
windowed_batches(
|
||||
data_paths,
|
||||
cont_cols,
|
||||
disc_cols,
|
||||
vocab,
|
||||
mean,
|
||||
std,
|
||||
batch_size=int(config["batch_size"]),
|
||||
seq_len=int(config["seq_len"]),
|
||||
max_batches=int(config["max_batches"]),
|
||||
return_file_id=False,
|
||||
transforms=transforms,
|
||||
shuffle_buffer=int(config.get("shuffle_buffer", 0)),
|
||||
)
|
||||
):
|
||||
x_cont, _ = batch
|
||||
x_cont = x_cont.to(device)
|
||||
trend, pred_next = temporal_model.forward_teacher(x_cont)
|
||||
temporal_loss = F.mse_loss(pred_next, x_cont[:, 1:, :])
|
||||
opt_temporal.zero_grad()
|
||||
temporal_loss.backward()
|
||||
if float(config.get("grad_clip", 0.0)) > 0:
|
||||
torch.nn.utils.clip_grad_norm_(temporal_model.parameters(), float(config["grad_clip"]))
|
||||
opt_temporal.step()
|
||||
if step % int(config["log_every"]) == 0:
|
||||
print("temporal_epoch", epoch, "step", step, "loss", float(temporal_loss))
|
||||
torch.save(temporal_model.state_dict(), os.path.join(out_dir, "temporal.pt"))
|
||||
|
||||
total_step = 0
|
||||
for epoch in range(int(config["epochs"])):
|
||||
for step, batch in enumerate(
|
||||
@@ -235,10 +286,17 @@ def main():
|
||||
x_cont = x_cont.to(device)
|
||||
x_disc = x_disc.to(device)
|
||||
|
||||
trend = None
|
||||
if temporal_model is not None:
|
||||
temporal_model.eval()
|
||||
with torch.no_grad():
|
||||
trend, _ = temporal_model.forward_teacher(x_cont)
|
||||
x_cont_resid = x_cont if trend is None else x_cont - trend
|
||||
|
||||
bsz = x_cont.size(0)
|
||||
t = torch.randint(0, int(config["timesteps"]), (bsz,), device=device)
|
||||
|
||||
x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
|
||||
x_cont_t, noise = q_sample_continuous(x_cont_resid, t, alphas_cumprod)
|
||||
|
||||
mask_tokens = torch.tensor(vocab_sizes, device=device)
|
||||
x_disc_t, mask = q_sample_discrete(
|
||||
@@ -253,7 +311,7 @@ def main():
|
||||
|
||||
cont_target = str(config.get("cont_target", "eps"))
|
||||
if cont_target == "x0":
|
||||
x0_target = x_cont
|
||||
x0_target = x_cont_resid
|
||||
if float(config.get("cont_clamp_x0", 0.0)) > 0:
|
||||
x0_target = torch.clamp(x0_target, -float(config["cont_clamp_x0"]), float(config["cont_clamp_x0"]))
|
||||
loss_base = (eps_pred - x0_target) ** 2
|
||||
@@ -308,11 +366,15 @@ def main():
|
||||
}
|
||||
if ema is not None:
|
||||
ckpt["ema"] = ema.state_dict()
|
||||
if temporal_model is not None:
|
||||
ckpt["temporal"] = temporal_model.state_dict()
|
||||
torch.save(ckpt, os.path.join(out_dir, "model_ckpt.pt"))
|
||||
|
||||
torch.save(model.state_dict(), os.path.join(out_dir, "model.pt"))
|
||||
if ema is not None:
|
||||
torch.save(ema.state_dict(), os.path.join(out_dir, "model_ema.pt"))
|
||||
if temporal_model is not None:
|
||||
torch.save(temporal_model.state_dict(), os.path.join(out_dir, "temporal.pt"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user