Use transformer for temporal trend model

This commit is contained in:
Mingzhe Yang
2026-02-04 02:40:57 +08:00
parent 84ac4cd2eb
commit 175fc684e3
6 changed files with 166 additions and 20 deletions

View File

@@ -57,9 +57,16 @@
"type6_features": ["P4_HT_PO","P2_24Vdc","P2_HILout"],
"shuffle_buffer": 256,
"use_temporal_stage1": true,
"temporal_backbone": "transformer",
"temporal_hidden_dim": 384,
"temporal_num_layers": 1,
"temporal_dropout": 0.0,
"temporal_pos_dim": 64,
"temporal_use_pos_embed": true,
"temporal_transformer_num_layers": 2,
"temporal_transformer_nhead": 4,
"temporal_transformer_ff_dim": 512,
"temporal_transformer_dropout": 0.1,
"temporal_epochs": 3,
"temporal_lr": 0.001,
"quantile_loss_weight": 0.2,

View File

@@ -51,9 +51,16 @@
"full_stats": true,
"shuffle_buffer": 1024,
"use_temporal_stage1": true,
"temporal_backbone": "transformer",
"temporal_hidden_dim": 512,
"temporal_num_layers": 2,
"temporal_dropout": 0.0,
"temporal_pos_dim": 64,
"temporal_use_pos_embed": true,
"temporal_transformer_num_layers": 2,
"temporal_transformer_nhead": 4,
"temporal_transformer_ff_dim": 512,
"temporal_transformer_dropout": 0.1,
"temporal_epochs": 5,
"temporal_lr": 0.0005,
"sample_batch_size": 4,

View File

@@ -13,7 +13,7 @@ import torch
import torch.nn.functional as F
from data_utils import load_split, inverse_quantile_transform, quantile_calibrate_to_real
from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, cosine_beta_schedule
from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, TemporalTransformerGenerator, cosine_beta_schedule
from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path
@@ -154,9 +154,16 @@ def main():
type5_cols = [c for c in type5_cols if c in cont_cols]
model_cont_cols = [c for c in cont_cols if c not in type1_cols and c not in type5_cols]
use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
temporal_backbone = str(cfg.get("temporal_backbone", "gru"))
temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
temporal_pos_dim = int(cfg.get("temporal_pos_dim", 64))
temporal_use_pos_embed = bool(cfg.get("temporal_use_pos_embed", True))
temporal_transformer_num_layers = int(cfg.get("temporal_transformer_num_layers", 2))
temporal_transformer_nhead = int(cfg.get("temporal_transformer_nhead", 4))
temporal_transformer_ff_dim = int(cfg.get("temporal_transformer_ff_dim", 512))
temporal_transformer_dropout = float(cfg.get("temporal_transformer_dropout", 0.1))
backbone_type = str(cfg.get("backbone_type", "gru"))
transformer_num_layers = int(cfg.get("transformer_num_layers", 2))
transformer_nhead = int(cfg.get("transformer_nhead", 4))
@@ -193,6 +200,18 @@ def main():
temporal_model = None
if use_temporal_stage1:
if temporal_backbone == "transformer":
temporal_model = TemporalTransformerGenerator(
input_dim=len(model_cont_cols),
hidden_dim=temporal_hidden_dim,
num_layers=temporal_transformer_num_layers,
nhead=temporal_transformer_nhead,
ff_dim=temporal_transformer_ff_dim,
dropout=temporal_transformer_dropout,
pos_dim=temporal_pos_dim,
use_pos_embed=temporal_use_pos_embed,
).to(device)
else:
temporal_model = TemporalGRUGenerator(
input_dim=len(model_cont_cols),
hidden_dim=temporal_hidden_dim,

View File

@@ -106,6 +106,79 @@ class TemporalGRUGenerator(nn.Module):
return torch.cat(outputs, dim=1)
class TemporalTransformerGenerator(nn.Module):
def __init__(
self,
input_dim: int,
hidden_dim: int = 256,
num_layers: int = 2,
nhead: int = 4,
ff_dim: int = 512,
dropout: float = 0.1,
pos_dim: int = 64,
use_pos_embed: bool = True,
):
super().__init__()
self.start_token = nn.Parameter(torch.zeros(input_dim))
self.in_proj = nn.Linear(input_dim, hidden_dim)
self.pos_dim = pos_dim
self.use_pos_embed = use_pos_embed
self.pos_proj = nn.Linear(pos_dim, hidden_dim) if use_pos_embed and pos_dim > 0 else None
encoder_layer = nn.TransformerEncoderLayer(
d_model=hidden_dim,
nhead=nhead,
dim_feedforward=ff_dim,
dropout=dropout,
batch_first=True,
activation="gelu",
)
self.backbone = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
self.out = nn.Linear(hidden_dim, input_dim)
def forward_teacher(self, x: torch.Tensor) -> torch.Tensor:
if x.size(1) < 2:
raise ValueError("sequence length must be >= 2 for teacher forcing")
inp = x[:, :-1, :]
out = self._encode(inp)
pred_next = self.out(out)
trend = torch.zeros_like(x)
trend[:, 0, :] = x[:, 0, :]
trend[:, 1:, :] = pred_next
return trend, pred_next
def generate(self, batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor:
context = self.start_token.unsqueeze(0).unsqueeze(1).expand(batch_size, 1, -1).to(device)
outputs = []
for _ in range(seq_len):
out = self._encode(context)
next_token = self.out(out[:, -1, :])
outputs.append(next_token.unsqueeze(1))
context = torch.cat([context, next_token.unsqueeze(1)], dim=1)
return torch.cat(outputs, dim=1)
def _encode(self, x: torch.Tensor) -> torch.Tensor:
feat = self.in_proj(x)
if self.pos_proj is not None and self.use_pos_embed and self.pos_dim > 0:
pos = self._positional_encoding(x.size(1), self.pos_dim, x.device)
pos = self.pos_proj(pos).unsqueeze(0).expand(x.size(0), -1, -1)
feat = feat + pos
mask = self._causal_mask(x.size(1), x.device)
return self.backbone(feat, mask=mask)
@staticmethod
def _causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
return torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()
@staticmethod
def _positional_encoding(seq_len: int, dim: int, device: torch.device) -> torch.Tensor:
pos = torch.arange(seq_len, device=device).float().unsqueeze(1)
div = torch.exp(torch.arange(0, dim, 2, device=device).float() * (-math.log(10000.0) / dim))
pe = torch.zeros(seq_len, dim, device=device)
pe[:, 0::2] = torch.sin(pos * div)
pe[:, 1::2] = torch.cos(pos * div)
return pe
class HybridDiffusionModel(nn.Module):
def __init__(
self,

View File

@@ -10,7 +10,7 @@ import torch
import torch.nn.functional as F
from data_utils import load_split
from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, cosine_beta_schedule
from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, TemporalTransformerGenerator, cosine_beta_schedule
from platform_utils import resolve_device, safe_path, ensure_dir
BASE_DIR = Path(__file__).resolve().parent
@@ -48,9 +48,16 @@ def main():
use_tanh_eps = bool(cfg.get("use_tanh_eps", False))
eps_scale = float(cfg.get("eps_scale", 1.0))
use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
temporal_backbone = str(cfg.get("temporal_backbone", "gru"))
temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
temporal_pos_dim = int(cfg.get("temporal_pos_dim", 64))
temporal_use_pos_embed = bool(cfg.get("temporal_use_pos_embed", True))
temporal_transformer_num_layers = int(cfg.get("temporal_transformer_num_layers", 2))
temporal_transformer_nhead = int(cfg.get("temporal_transformer_nhead", 4))
temporal_transformer_ff_dim = int(cfg.get("temporal_transformer_ff_dim", 512))
temporal_transformer_dropout = float(cfg.get("temporal_transformer_dropout", 0.1))
cont_target = str(cfg.get("cont_target", "eps"))
cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
model_time_dim = int(cfg.get("model_time_dim", 64))
@@ -108,6 +115,18 @@ def main():
temporal_model = None
if use_temporal_stage1:
if temporal_backbone == "transformer":
temporal_model = TemporalTransformerGenerator(
input_dim=len(cont_cols),
hidden_dim=temporal_hidden_dim,
num_layers=temporal_transformer_num_layers,
nhead=temporal_transformer_nhead,
ff_dim=temporal_transformer_ff_dim,
dropout=temporal_transformer_dropout,
pos_dim=temporal_pos_dim,
use_pos_embed=temporal_use_pos_embed,
).to(DEVICE)
else:
temporal_model = TemporalGRUGenerator(
input_dim=len(cont_cols),
hidden_dim=temporal_hidden_dim,

View File

@@ -15,6 +15,7 @@ from data_utils import load_split, windowed_batches
from hybrid_diffusion import (
HybridDiffusionModel,
TemporalGRUGenerator,
TemporalTransformerGenerator,
cosine_beta_schedule,
q_sample_continuous,
q_sample_discrete,
@@ -66,9 +67,16 @@ DEFAULTS = {
"cont_target": "eps", # eps | x0
"cont_clamp_x0": 0.0,
"use_temporal_stage1": True,
"temporal_backbone": "gru",
"temporal_hidden_dim": 256,
"temporal_num_layers": 1,
"temporal_dropout": 0.0,
"temporal_pos_dim": 64,
"temporal_use_pos_embed": True,
"temporal_transformer_num_layers": 2,
"temporal_transformer_nhead": 4,
"temporal_transformer_ff_dim": 512,
"temporal_transformer_dropout": 0.1,
"temporal_epochs": 2,
"temporal_lr": 1e-3,
"quantile_loss_weight": 0.0,
@@ -226,6 +234,19 @@ def main():
temporal_model = None
opt_temporal = None
if bool(config.get("use_temporal_stage1", False)):
temporal_backbone = str(config.get("temporal_backbone", "gru"))
if temporal_backbone == "transformer":
temporal_model = TemporalTransformerGenerator(
input_dim=len(model_cont_cols),
hidden_dim=int(config.get("temporal_hidden_dim", 256)),
num_layers=int(config.get("temporal_transformer_num_layers", 2)),
nhead=int(config.get("temporal_transformer_nhead", 4)),
ff_dim=int(config.get("temporal_transformer_ff_dim", 512)),
dropout=float(config.get("temporal_transformer_dropout", 0.1)),
pos_dim=int(config.get("temporal_pos_dim", 64)),
use_pos_embed=bool(config.get("temporal_use_pos_embed", True)),
).to(device)
else:
temporal_model = TemporalGRUGenerator(
input_dim=len(model_cont_cols),
hidden_dim=int(config.get("temporal_hidden_dim", 256)),