Use transformer for temporal trend model

This commit is contained in:
Mingzhe Yang
2026-02-04 02:40:57 +08:00
parent 84ac4cd2eb
commit 175fc684e3
6 changed files with 166 additions and 20 deletions

View File

@@ -13,7 +13,7 @@ import torch
import torch.nn.functional as F
from data_utils import load_split, inverse_quantile_transform, quantile_calibrate_to_real
from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, cosine_beta_schedule
from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, TemporalTransformerGenerator, cosine_beta_schedule
from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path
@@ -154,9 +154,16 @@ def main():
type5_cols = [c for c in type5_cols if c in cont_cols]
model_cont_cols = [c for c in cont_cols if c not in type1_cols and c not in type5_cols]
use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
temporal_backbone = str(cfg.get("temporal_backbone", "gru"))
temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
temporal_pos_dim = int(cfg.get("temporal_pos_dim", 64))
temporal_use_pos_embed = bool(cfg.get("temporal_use_pos_embed", True))
temporal_transformer_num_layers = int(cfg.get("temporal_transformer_num_layers", 2))
temporal_transformer_nhead = int(cfg.get("temporal_transformer_nhead", 4))
temporal_transformer_ff_dim = int(cfg.get("temporal_transformer_ff_dim", 512))
temporal_transformer_dropout = float(cfg.get("temporal_transformer_dropout", 0.1))
backbone_type = str(cfg.get("backbone_type", "gru"))
transformer_num_layers = int(cfg.get("transformer_num_layers", 2))
transformer_nhead = int(cfg.get("transformer_nhead", 4))
@@ -193,12 +200,24 @@ def main():
temporal_model = None
if use_temporal_stage1:
temporal_model = TemporalGRUGenerator(
input_dim=len(model_cont_cols),
hidden_dim=temporal_hidden_dim,
num_layers=temporal_num_layers,
dropout=temporal_dropout,
).to(device)
if temporal_backbone == "transformer":
temporal_model = TemporalTransformerGenerator(
input_dim=len(model_cont_cols),
hidden_dim=temporal_hidden_dim,
num_layers=temporal_transformer_num_layers,
nhead=temporal_transformer_nhead,
ff_dim=temporal_transformer_ff_dim,
dropout=temporal_transformer_dropout,
pos_dim=temporal_pos_dim,
use_pos_embed=temporal_use_pos_embed,
).to(device)
else:
temporal_model = TemporalGRUGenerator(
input_dim=len(model_cont_cols),
hidden_dim=temporal_hidden_dim,
num_layers=temporal_num_layers,
dropout=temporal_dropout,
).to(device)
temporal_path = Path(args.model_path).with_name("temporal.pt")
if not temporal_path.exists():
raise SystemExit(f"missing temporal model file: {temporal_path}")