Add optional transformer backbone for the temporal trend model (GRU remains the default)
This commit is contained in:
@@ -10,7 +10,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from data_utils import load_split
|
||||
from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, cosine_beta_schedule
|
||||
from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, TemporalTransformerGenerator, cosine_beta_schedule
|
||||
from platform_utils import resolve_device, safe_path, ensure_dir
|
||||
|
||||
BASE_DIR = Path(__file__).resolve().parent
|
||||
@@ -48,9 +48,16 @@ def main():
|
||||
use_tanh_eps = bool(cfg.get("use_tanh_eps", False))
|
||||
eps_scale = float(cfg.get("eps_scale", 1.0))
|
||||
use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
|
||||
temporal_backbone = str(cfg.get("temporal_backbone", "gru"))
|
||||
temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
|
||||
temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
|
||||
temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
|
||||
temporal_pos_dim = int(cfg.get("temporal_pos_dim", 64))
|
||||
temporal_use_pos_embed = bool(cfg.get("temporal_use_pos_embed", True))
|
||||
temporal_transformer_num_layers = int(cfg.get("temporal_transformer_num_layers", 2))
|
||||
temporal_transformer_nhead = int(cfg.get("temporal_transformer_nhead", 4))
|
||||
temporal_transformer_ff_dim = int(cfg.get("temporal_transformer_ff_dim", 512))
|
||||
temporal_transformer_dropout = float(cfg.get("temporal_transformer_dropout", 0.1))
|
||||
cont_target = str(cfg.get("cont_target", "eps"))
|
||||
cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
|
||||
model_time_dim = int(cfg.get("model_time_dim", 64))
|
||||
@@ -108,12 +115,24 @@ def main():
|
||||
|
||||
temporal_model = None
|
||||
if use_temporal_stage1:
|
||||
temporal_model = TemporalGRUGenerator(
|
||||
input_dim=len(cont_cols),
|
||||
hidden_dim=temporal_hidden_dim,
|
||||
num_layers=temporal_num_layers,
|
||||
dropout=temporal_dropout,
|
||||
).to(DEVICE)
|
||||
if temporal_backbone == "transformer":
|
||||
temporal_model = TemporalTransformerGenerator(
|
||||
input_dim=len(cont_cols),
|
||||
hidden_dim=temporal_hidden_dim,
|
||||
num_layers=temporal_transformer_num_layers,
|
||||
nhead=temporal_transformer_nhead,
|
||||
ff_dim=temporal_transformer_ff_dim,
|
||||
dropout=temporal_transformer_dropout,
|
||||
pos_dim=temporal_pos_dim,
|
||||
use_pos_embed=temporal_use_pos_embed,
|
||||
).to(DEVICE)
|
||||
else:
|
||||
temporal_model = TemporalGRUGenerator(
|
||||
input_dim=len(cont_cols),
|
||||
hidden_dim=temporal_hidden_dim,
|
||||
num_layers=temporal_num_layers,
|
||||
dropout=temporal_dropout,
|
||||
).to(DEVICE)
|
||||
temporal_path = BASE_DIR / "results" / "temporal.pt"
|
||||
if not temporal_path.exists():
|
||||
raise SystemExit(f"missing temporal model file: {temporal_path}")
|
||||
|
||||
Reference in New Issue
Block a user