Use transformer for temporal trend model
@@ -13,7 +13,7 @@ import torch
 import torch.nn.functional as F
 
 from data_utils import load_split, inverse_quantile_transform, quantile_calibrate_to_real
-from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, cosine_beta_schedule
+from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, TemporalTransformerGenerator, cosine_beta_schedule
 from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path
 
 
@@ -154,9 +154,16 @@ def main():
     type5_cols = [c for c in type5_cols if c in cont_cols]
     model_cont_cols = [c for c in cont_cols if c not in type1_cols and c not in type5_cols]
     use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
+    temporal_backbone = str(cfg.get("temporal_backbone", "gru"))
     temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
     temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
     temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
+    temporal_pos_dim = int(cfg.get("temporal_pos_dim", 64))
+    temporal_use_pos_embed = bool(cfg.get("temporal_use_pos_embed", True))
+    temporal_transformer_num_layers = int(cfg.get("temporal_transformer_num_layers", 2))
+    temporal_transformer_nhead = int(cfg.get("temporal_transformer_nhead", 4))
+    temporal_transformer_ff_dim = int(cfg.get("temporal_transformer_ff_dim", 512))
+    temporal_transformer_dropout = float(cfg.get("temporal_transformer_dropout", 0.1))
     backbone_type = str(cfg.get("backbone_type", "gru"))
     transformer_num_layers = int(cfg.get("transformer_num_layers", 2))
     transformer_nhead = int(cfg.get("transformer_nhead", 4))
@@ -193,12 +200,24 @@ def main():
 
     temporal_model = None
     if use_temporal_stage1:
-        temporal_model = TemporalGRUGenerator(
-            input_dim=len(model_cont_cols),
-            hidden_dim=temporal_hidden_dim,
-            num_layers=temporal_num_layers,
-            dropout=temporal_dropout,
-        ).to(device)
+        if temporal_backbone == "transformer":
+            temporal_model = TemporalTransformerGenerator(
+                input_dim=len(model_cont_cols),
+                hidden_dim=temporal_hidden_dim,
+                num_layers=temporal_transformer_num_layers,
+                nhead=temporal_transformer_nhead,
+                ff_dim=temporal_transformer_ff_dim,
+                dropout=temporal_transformer_dropout,
+                pos_dim=temporal_pos_dim,
+                use_pos_embed=temporal_use_pos_embed,
+            ).to(device)
+        else:
+            temporal_model = TemporalGRUGenerator(
+                input_dim=len(model_cont_cols),
+                hidden_dim=temporal_hidden_dim,
+                num_layers=temporal_num_layers,
+                dropout=temporal_dropout,
+            ).to(device)
         temporal_path = Path(args.model_path).with_name("temporal.pt")
         if not temporal_path.exists():
             raise SystemExit(f"missing temporal model file: {temporal_path}")
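For reference, the new code path is driven entirely by config keys read through cfg.get() in main(). A config that opts into the transformer backbone might look like the sketch below; the keys and default values come from the diff above, while the dict layout and switching temporal_backbone to "transformer" are illustrative assumptions, not a file from the repository.

cfg = {
    "use_temporal_stage1": True,            # enable the stage-1 temporal trend model
    "temporal_backbone": "transformer",     # "gru" (default) or "transformer"
    "temporal_hidden_dim": 256,
    "temporal_num_layers": 1,               # GRU path only
    "temporal_dropout": 0.0,                # GRU path only
    "temporal_pos_dim": 64,
    "temporal_use_pos_embed": True,
    "temporal_transformer_num_layers": 2,
    "temporal_transformer_nhead": 4,
    "temporal_transformer_ff_dim": 512,
    "temporal_transformer_dropout": 0.1,
}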
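The diff only shows how TemporalTransformerGenerator is constructed; its definition lives in hybrid_diffusion.py and is not part of this commit view. As a rough sketch of what a module with that constructor signature could look like (a learned positional embedding concatenated to the inputs, an nn.TransformerEncoder over the sequence, and a projection back to the continuous columns), the code below is purely illustrative and may differ from the actual implementation; max_len in particular is an invented parameter.

import torch
import torch.nn as nn


class TemporalTransformerGenerator(nn.Module):
    """Illustrative sketch only; the real class in hybrid_diffusion.py may differ."""

    def __init__(self, input_dim, hidden_dim=256, num_layers=2, nhead=4,
                 ff_dim=512, dropout=0.1, pos_dim=64, use_pos_embed=True,
                 max_len=512):
        super().__init__()
        self.use_pos_embed = use_pos_embed
        if use_pos_embed:
            # Learned positional embedding, concatenated to the per-step features.
            self.pos_embed = nn.Embedding(max_len, pos_dim)
        in_dim = input_dim + (pos_dim if use_pos_embed else 0)
        self.input_proj = nn.Linear(in_dim, hidden_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=nhead, dim_feedforward=ff_dim,
            dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # Project back to the continuous feature space, mirroring the role the
        # GRU generator plays in the branch above.
        self.output_proj = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        # x: (batch, seq_len, input_dim)
        if self.use_pos_embed:
            positions = torch.arange(x.size(1), device=x.device)
            pos = self.pos_embed(positions).unsqueeze(0).expand(x.size(0), -1, -1)
            x = torch.cat([x, pos], dim=-1)
        h = self.encoder(self.input_proj(x))
        return self.output_proj(h)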