Use transformer for temporal trend model

Mingzhe Yang
2026-02-04 02:40:57 +08:00
parent 84ac4cd2eb
commit 175fc684e3
6 changed files with 166 additions and 20 deletions


@@ -15,6 +15,7 @@ from data_utils import load_split, windowed_batches
 from hybrid_diffusion import (
     HybridDiffusionModel,
     TemporalGRUGenerator,
+    TemporalTransformerGenerator,
     cosine_beta_schedule,
     q_sample_continuous,
     q_sample_discrete,
@@ -66,9 +67,16 @@ DEFAULTS = {
"cont_target": "eps", # eps | x0
"cont_clamp_x0": 0.0,
"use_temporal_stage1": True,
"temporal_backbone": "gru",
"temporal_hidden_dim": 256,
"temporal_num_layers": 1,
"temporal_dropout": 0.0,
"temporal_pos_dim": 64,
"temporal_use_pos_embed": True,
"temporal_transformer_num_layers": 2,
"temporal_transformer_nhead": 4,
"temporal_transformer_ff_dim": 512,
"temporal_transformer_dropout": 0.1,
"temporal_epochs": 2,
"temporal_lr": 1e-3,
"quantile_loss_weight": 0.0,
@@ -226,12 +234,25 @@ def main():
     temporal_model = None
     opt_temporal = None
     if bool(config.get("use_temporal_stage1", False)):
-        temporal_model = TemporalGRUGenerator(
-            input_dim=len(model_cont_cols),
-            hidden_dim=int(config.get("temporal_hidden_dim", 256)),
-            num_layers=int(config.get("temporal_num_layers", 1)),
-            dropout=float(config.get("temporal_dropout", 0.0)),
-        ).to(device)
+        temporal_backbone = str(config.get("temporal_backbone", "gru"))
+        if temporal_backbone == "transformer":
+            temporal_model = TemporalTransformerGenerator(
+                input_dim=len(model_cont_cols),
+                hidden_dim=int(config.get("temporal_hidden_dim", 256)),
+                num_layers=int(config.get("temporal_transformer_num_layers", 2)),
+                nhead=int(config.get("temporal_transformer_nhead", 4)),
+                ff_dim=int(config.get("temporal_transformer_ff_dim", 512)),
+                dropout=float(config.get("temporal_transformer_dropout", 0.1)),
+                pos_dim=int(config.get("temporal_pos_dim", 64)),
+                use_pos_embed=bool(config.get("temporal_use_pos_embed", True)),
+            ).to(device)
+        else:
+            temporal_model = TemporalGRUGenerator(
+                input_dim=len(model_cont_cols),
+                hidden_dim=int(config.get("temporal_hidden_dim", 256)),
+                num_layers=int(config.get("temporal_num_layers", 1)),
+                dropout=float(config.get("temporal_dropout", 0.0)),
+            ).to(device)
         opt_temporal = torch.optim.Adam(
             temporal_model.parameters(),
             lr=float(config.get("temporal_lr", config["lr"])),
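
For context, the keyword arguments passed above imply a constructor interface like the one below. This is only a minimal sketch of what TemporalTransformerGenerator in hybrid_diffusion.py might look like; the actual implementation is not part of this diff. The input projection, learned positional embedding, TransformerEncoder stack, output projection, and the max_len cap are all assumptions.

    import torch
    import torch.nn as nn

    class TemporalTransformerGeneratorSketch(nn.Module):
        """Hypothetical stand-in matching the keyword arguments used in main()."""

        def __init__(self, input_dim, hidden_dim=256, num_layers=2, nhead=4,
                     ff_dim=512, dropout=0.1, pos_dim=64, use_pos_embed=True,
                     max_len=512):
            super().__init__()
            self.use_pos_embed = use_pos_embed
            self.input_proj = nn.Linear(input_dim, hidden_dim)
            if use_pos_embed:
                # Learned positional embedding, projected into the hidden size.
                self.pos_embed = nn.Embedding(max_len, pos_dim)
                self.pos_proj = nn.Linear(pos_dim, hidden_dim)
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=nhead,
                dim_feedforward=ff_dim,
                dropout=dropout,
                batch_first=True,
            )
            self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
            self.output_proj = nn.Linear(hidden_dim, input_dim)

        def forward(self, x):
            # x: (batch, seq_len, input_dim) window of continuous features.
            h = self.input_proj(x)
            if self.use_pos_embed:
                positions = torch.arange(x.size(1), device=x.device)
                h = h + self.pos_proj(self.pos_embed(positions))
            h = self.encoder(h)
            return self.output_proj(h)

batch_first=True keeps the (batch, seq, feature) layout that the GRU generator presumably consumes, so the two backbones stay drop-in interchangeable in the stage-1 training loop.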