Use transformer for temporal trend model
@@ -15,6 +15,7 @@ from data_utils import load_split, windowed_batches
 from hybrid_diffusion import (
     HybridDiffusionModel,
     TemporalGRUGenerator,
+    TemporalTransformerGenerator,
     cosine_beta_schedule,
     q_sample_continuous,
     q_sample_discrete,
@@ -66,9 +67,16 @@ DEFAULTS = {
     "cont_target": "eps",  # eps | x0
     "cont_clamp_x0": 0.0,
     "use_temporal_stage1": True,
+    "temporal_backbone": "gru",
     "temporal_hidden_dim": 256,
     "temporal_num_layers": 1,
     "temporal_dropout": 0.0,
+    "temporal_pos_dim": 64,
+    "temporal_use_pos_embed": True,
+    "temporal_transformer_num_layers": 2,
+    "temporal_transformer_nhead": 4,
+    "temporal_transformer_ff_dim": 512,
+    "temporal_transformer_dropout": 0.1,
     "temporal_epochs": 2,
     "temporal_lr": 1e-3,
     "quantile_loss_weight": 0.0,
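The hunk leaves the existing GRU keys untouched and adds a parallel temporal_transformer_* group, so each backbone keeps its own depth and dropout settings. A minimal sketch of switching stage 1 to the transformer, assuming config is produced by merging user overrides over DEFAULTS (the merge itself is outside this diff):

# Assumed override pattern; DEFAULTS is the dict patched above.
config = dict(DEFAULTS)
config.update({
    "temporal_backbone": "transformer",    # any other value falls back to the GRU
    "temporal_transformer_num_layers": 4,  # deeper encoder than the default 2
    "temporal_transformer_nhead": 8,       # should divide temporal_hidden_dim (256)
})

One constraint worth noting: if TemporalTransformerGenerator wraps torch's nn.TransformerEncoderLayer (not shown in this diff), temporal_transformer_nhead must divide temporal_hidden_dim, since PyTorch's multi-head attention requires embed_dim % num_heads == 0.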
@@ -226,12 +234,25 @@ def main():
     temporal_model = None
     opt_temporal = None
     if bool(config.get("use_temporal_stage1", False)):
-        temporal_model = TemporalGRUGenerator(
-            input_dim=len(model_cont_cols),
-            hidden_dim=int(config.get("temporal_hidden_dim", 256)),
-            num_layers=int(config.get("temporal_num_layers", 1)),
-            dropout=float(config.get("temporal_dropout", 0.0)),
-        ).to(device)
+        temporal_backbone = str(config.get("temporal_backbone", "gru"))
+        if temporal_backbone == "transformer":
+            temporal_model = TemporalTransformerGenerator(
+                input_dim=len(model_cont_cols),
+                hidden_dim=int(config.get("temporal_hidden_dim", 256)),
+                num_layers=int(config.get("temporal_transformer_num_layers", 2)),
+                nhead=int(config.get("temporal_transformer_nhead", 4)),
+                ff_dim=int(config.get("temporal_transformer_ff_dim", 512)),
+                dropout=float(config.get("temporal_transformer_dropout", 0.1)),
+                pos_dim=int(config.get("temporal_pos_dim", 64)),
+                use_pos_embed=bool(config.get("temporal_use_pos_embed", True)),
+            ).to(device)
+        else:
+            temporal_model = TemporalGRUGenerator(
+                input_dim=len(model_cont_cols),
+                hidden_dim=int(config.get("temporal_hidden_dim", 256)),
+                num_layers=int(config.get("temporal_num_layers", 1)),
+                dropout=float(config.get("temporal_dropout", 0.0)),
+            ).to(device)
         opt_temporal = torch.optim.Adam(
             temporal_model.parameters(),
             lr=float(config.get("temporal_lr", config["lr"])),
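TemporalTransformerGenerator itself lives in hybrid_diffusion and is not part of this diff; the call above only pins down its constructor signature. Below is a minimal sketch of a generator consistent with that signature, built on torch.nn.TransformerEncoder with optional sinusoidal position features of width pos_dim. Every detail is an assumption, not the repository's implementation:

import math
import torch
import torch.nn as nn

class TemporalTransformerGenerator(nn.Module):
    """Sketch of a stage-1 trend generator matching the constructor used in
    main(). The real class is in hybrid_diffusion, which this diff does not
    show; the shapes and layers below are assumptions."""

    def __init__(self, input_dim, hidden_dim=256, num_layers=2, nhead=4,
                 ff_dim=512, dropout=0.1, pos_dim=64, use_pos_embed=True):
        super().__init__()
        self.use_pos_embed = use_pos_embed
        self.pos_dim = pos_dim
        # Position features are concatenated to the inputs before projection.
        self.in_proj = nn.Linear(
            input_dim + (pos_dim if use_pos_embed else 0), hidden_dim)
        layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=nhead, dim_feedforward=ff_dim,
            dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.out_proj = nn.Linear(hidden_dim, input_dim)

    def _pos_embed(self, T, device):
        # Fixed sinusoidal position features, pos_dim per timestep.
        pos = torch.arange(T, device=device, dtype=torch.float32).unsqueeze(1)
        div = torch.exp(
            torch.arange(0, self.pos_dim, 2, device=device, dtype=torch.float32)
            * (-math.log(10000.0) / self.pos_dim))
        pe = torch.zeros(T, self.pos_dim, device=device)
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        return pe

    def forward(self, x):
        # x: (batch, T, input_dim) window of continuous features.
        B, T, _ = x.shape
        if self.use_pos_embed:
            pe = self._pos_embed(T, x.device).expand(B, T, self.pos_dim)
            x = torch.cat([x, pe], dim=-1)
        h = self.encoder(self.in_proj(x))
        return self.out_proj(h)  # predicted trend, same shape as the input

As with the GRU path, the instance is moved to device and trained with Adam at temporal_lr, which keeps the two backbones drop-in interchangeable in the rest of the training loop.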