From 175fc684e31aac866c7f945e3564166d70f35ab1 Mon Sep 17 00:00:00 2001
From: Mingzhe Yang
Date: Wed, 4 Feb 2026 02:40:57 +0800
Subject: [PATCH] Use transformer for temporal trend model

---
 example/config.json                 |  7 +++
 example/config_temporal_strong.json |  7 +++
 example/export_samples.py           | 33 ++++++++++---
 example/hybrid_diffusion.py         | 73 +++++++++++++++++++++++++++++
 example/sample.py                   | 33 ++++++++++---
 example/train.py                    | 33 ++++++++++---
 6 files changed, 166 insertions(+), 20 deletions(-)

diff --git a/example/config.json b/example/config.json
index d4c73be..66c0e96 100644
--- a/example/config.json
+++ b/example/config.json
@@ -57,9 +57,16 @@
   "type6_features": ["P4_HT_PO","P2_24Vdc","P2_HILout"],
   "shuffle_buffer": 256,
   "use_temporal_stage1": true,
+  "temporal_backbone": "transformer",
   "temporal_hidden_dim": 384,
   "temporal_num_layers": 1,
   "temporal_dropout": 0.0,
+  "temporal_pos_dim": 64,
+  "temporal_use_pos_embed": true,
+  "temporal_transformer_num_layers": 2,
+  "temporal_transformer_nhead": 4,
+  "temporal_transformer_ff_dim": 512,
+  "temporal_transformer_dropout": 0.1,
   "temporal_epochs": 3,
   "temporal_lr": 0.001,
   "quantile_loss_weight": 0.2,
diff --git a/example/config_temporal_strong.json b/example/config_temporal_strong.json
index b5af569..a476049 100644
--- a/example/config_temporal_strong.json
+++ b/example/config_temporal_strong.json
@@ -51,9 +51,16 @@
   "full_stats": true,
   "shuffle_buffer": 1024,
   "use_temporal_stage1": true,
+  "temporal_backbone": "transformer",
   "temporal_hidden_dim": 512,
   "temporal_num_layers": 2,
   "temporal_dropout": 0.0,
+  "temporal_pos_dim": 64,
+  "temporal_use_pos_embed": true,
+  "temporal_transformer_num_layers": 2,
+  "temporal_transformer_nhead": 4,
+  "temporal_transformer_ff_dim": 512,
+  "temporal_transformer_dropout": 0.1,
   "temporal_epochs": 5,
   "temporal_lr": 0.0005,
   "sample_batch_size": 4,
diff --git a/example/export_samples.py b/example/export_samples.py
index 7c11f63..cc54ab8 100644
--- a/example/export_samples.py
+++ b/example/export_samples.py
@@ -13,7 +13,7 @@ import torch
 import torch.nn.functional as F
 
 from data_utils import load_split, inverse_quantile_transform, quantile_calibrate_to_real
-from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, cosine_beta_schedule
+from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, TemporalTransformerGenerator, cosine_beta_schedule
 from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path
 
 
@@ -154,9 +154,16 @@ def main():
     type5_cols = [c for c in type5_cols if c in cont_cols]
     model_cont_cols = [c for c in cont_cols if c not in type1_cols and c not in type5_cols]
     use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
+    temporal_backbone = str(cfg.get("temporal_backbone", "gru"))
     temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
     temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
     temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
+    temporal_pos_dim = int(cfg.get("temporal_pos_dim", 64))
+    temporal_use_pos_embed = bool(cfg.get("temporal_use_pos_embed", True))
+    temporal_transformer_num_layers = int(cfg.get("temporal_transformer_num_layers", 2))
+    temporal_transformer_nhead = int(cfg.get("temporal_transformer_nhead", 4))
+    temporal_transformer_ff_dim = int(cfg.get("temporal_transformer_ff_dim", 512))
+    temporal_transformer_dropout = float(cfg.get("temporal_transformer_dropout", 0.1))
     backbone_type = str(cfg.get("backbone_type", "gru"))
     transformer_num_layers = int(cfg.get("transformer_num_layers", 2))
     transformer_nhead = int(cfg.get("transformer_nhead", 4))
@@ -193,12 +200,24 @@ def main():
 
     temporal_model = None
     if use_temporal_stage1:
-        temporal_model = TemporalGRUGenerator(
-            input_dim=len(model_cont_cols),
-            hidden_dim=temporal_hidden_dim,
-            num_layers=temporal_num_layers,
-            dropout=temporal_dropout,
-        ).to(device)
+        if temporal_backbone == "transformer":
+            temporal_model = TemporalTransformerGenerator(
+                input_dim=len(model_cont_cols),
+                hidden_dim=temporal_hidden_dim,
+                num_layers=temporal_transformer_num_layers,
+                nhead=temporal_transformer_nhead,
+                ff_dim=temporal_transformer_ff_dim,
+                dropout=temporal_transformer_dropout,
+                pos_dim=temporal_pos_dim,
+                use_pos_embed=temporal_use_pos_embed,
+            ).to(device)
+        else:
+            temporal_model = TemporalGRUGenerator(
+                input_dim=len(model_cont_cols),
+                hidden_dim=temporal_hidden_dim,
+                num_layers=temporal_num_layers,
+                dropout=temporal_dropout,
+            ).to(device)
         temporal_path = Path(args.model_path).with_name("temporal.pt")
         if not temporal_path.exists():
             raise SystemExit(f"missing temporal model file: {temporal_path}")
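
Note: the backbone dispatch just added to export_samples.py is duplicated, modulo
column lists and device names, in sample.py and train.py below. For reference only,
the three copies are equivalent to this sketch of a shared helper; the name
build_temporal_model is hypothetical and the helper is not part of this patch:

    from hybrid_diffusion import TemporalGRUGenerator, TemporalTransformerGenerator

    def build_temporal_model(cfg: dict, input_dim: int):
        # Dispatch on the new "temporal_backbone" key; defaults mirror the
        # cfg.get(...) fallbacks used in all three scripts.
        if str(cfg.get("temporal_backbone", "gru")) == "transformer":
            return TemporalTransformerGenerator(
                input_dim=input_dim,
                hidden_dim=int(cfg.get("temporal_hidden_dim", 256)),
                num_layers=int(cfg.get("temporal_transformer_num_layers", 2)),
                nhead=int(cfg.get("temporal_transformer_nhead", 4)),
                ff_dim=int(cfg.get("temporal_transformer_ff_dim", 512)),
                dropout=float(cfg.get("temporal_transformer_dropout", 0.1)),
                pos_dim=int(cfg.get("temporal_pos_dim", 64)),
                use_pos_embed=bool(cfg.get("temporal_use_pos_embed", True)),
            )
        return TemporalGRUGenerator(
            input_dim=input_dim,
            hidden_dim=int(cfg.get("temporal_hidden_dim", 256)),
            num_layers=int(cfg.get("temporal_num_layers", 1)),
            dropout=float(cfg.get("temporal_dropout", 0.0)),
        )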
int(cfg.get("transformer_nhead", 4)) @@ -193,12 +200,24 @@ def main(): temporal_model = None if use_temporal_stage1: - temporal_model = TemporalGRUGenerator( - input_dim=len(model_cont_cols), - hidden_dim=temporal_hidden_dim, - num_layers=temporal_num_layers, - dropout=temporal_dropout, - ).to(device) + if temporal_backbone == "transformer": + temporal_model = TemporalTransformerGenerator( + input_dim=len(model_cont_cols), + hidden_dim=temporal_hidden_dim, + num_layers=temporal_transformer_num_layers, + nhead=temporal_transformer_nhead, + ff_dim=temporal_transformer_ff_dim, + dropout=temporal_transformer_dropout, + pos_dim=temporal_pos_dim, + use_pos_embed=temporal_use_pos_embed, + ).to(device) + else: + temporal_model = TemporalGRUGenerator( + input_dim=len(model_cont_cols), + hidden_dim=temporal_hidden_dim, + num_layers=temporal_num_layers, + dropout=temporal_dropout, + ).to(device) temporal_path = Path(args.model_path).with_name("temporal.pt") if not temporal_path.exists(): raise SystemExit(f"missing temporal model file: {temporal_path}") diff --git a/example/hybrid_diffusion.py b/example/hybrid_diffusion.py index 1b011a2..82fcffd 100755 --- a/example/hybrid_diffusion.py +++ b/example/hybrid_diffusion.py @@ -106,6 +106,79 @@ class TemporalGRUGenerator(nn.Module): return torch.cat(outputs, dim=1) +class TemporalTransformerGenerator(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int = 256, + num_layers: int = 2, + nhead: int = 4, + ff_dim: int = 512, + dropout: float = 0.1, + pos_dim: int = 64, + use_pos_embed: bool = True, + ): + super().__init__() + self.start_token = nn.Parameter(torch.zeros(input_dim)) + self.in_proj = nn.Linear(input_dim, hidden_dim) + self.pos_dim = pos_dim + self.use_pos_embed = use_pos_embed + self.pos_proj = nn.Linear(pos_dim, hidden_dim) if use_pos_embed and pos_dim > 0 else None + encoder_layer = nn.TransformerEncoderLayer( + d_model=hidden_dim, + nhead=nhead, + dim_feedforward=ff_dim, + dropout=dropout, + batch_first=True, + activation="gelu", + ) + self.backbone = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) + self.out = nn.Linear(hidden_dim, input_dim) + + def forward_teacher(self, x: torch.Tensor) -> torch.Tensor: + if x.size(1) < 2: + raise ValueError("sequence length must be >= 2 for teacher forcing") + inp = x[:, :-1, :] + out = self._encode(inp) + pred_next = self.out(out) + trend = torch.zeros_like(x) + trend[:, 0, :] = x[:, 0, :] + trend[:, 1:, :] = pred_next + return trend, pred_next + + def generate(self, batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor: + context = self.start_token.unsqueeze(0).unsqueeze(1).expand(batch_size, 1, -1).to(device) + outputs = [] + for _ in range(seq_len): + out = self._encode(context) + next_token = self.out(out[:, -1, :]) + outputs.append(next_token.unsqueeze(1)) + context = torch.cat([context, next_token.unsqueeze(1)], dim=1) + return torch.cat(outputs, dim=1) + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + feat = self.in_proj(x) + if self.pos_proj is not None and self.use_pos_embed and self.pos_dim > 0: + pos = self._positional_encoding(x.size(1), self.pos_dim, x.device) + pos = self.pos_proj(pos).unsqueeze(0).expand(x.size(0), -1, -1) + feat = feat + pos + mask = self._causal_mask(x.size(1), x.device) + return self.backbone(feat, mask=mask) + + @staticmethod + def _causal_mask(seq_len: int, device: torch.device) -> torch.Tensor: + return torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool() + + @staticmethod + def 
diff --git a/example/sample.py b/example/sample.py
index 40baa7b..ec963bd 100755
--- a/example/sample.py
+++ b/example/sample.py
@@ -10,7 +10,7 @@ import torch
 import torch.nn.functional as F
 
 from data_utils import load_split
-from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, cosine_beta_schedule
+from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, TemporalTransformerGenerator, cosine_beta_schedule
 from platform_utils import resolve_device, safe_path, ensure_dir
 
 BASE_DIR = Path(__file__).resolve().parent
@@ -48,9 +48,16 @@ def main():
     use_tanh_eps = bool(cfg.get("use_tanh_eps", False))
     eps_scale = float(cfg.get("eps_scale", 1.0))
     use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
+    temporal_backbone = str(cfg.get("temporal_backbone", "gru"))
     temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
     temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
     temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
+    temporal_pos_dim = int(cfg.get("temporal_pos_dim", 64))
+    temporal_use_pos_embed = bool(cfg.get("temporal_use_pos_embed", True))
+    temporal_transformer_num_layers = int(cfg.get("temporal_transformer_num_layers", 2))
+    temporal_transformer_nhead = int(cfg.get("temporal_transformer_nhead", 4))
+    temporal_transformer_ff_dim = int(cfg.get("temporal_transformer_ff_dim", 512))
+    temporal_transformer_dropout = float(cfg.get("temporal_transformer_dropout", 0.1))
     cont_target = str(cfg.get("cont_target", "eps"))
     cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
     model_time_dim = int(cfg.get("model_time_dim", 64))
@@ -108,12 +115,24 @@ def main():
 
     temporal_model = None
     if use_temporal_stage1:
-        temporal_model = TemporalGRUGenerator(
-            input_dim=len(cont_cols),
-            hidden_dim=temporal_hidden_dim,
-            num_layers=temporal_num_layers,
-            dropout=temporal_dropout,
-        ).to(DEVICE)
+        if temporal_backbone == "transformer":
+            temporal_model = TemporalTransformerGenerator(
+                input_dim=len(cont_cols),
+                hidden_dim=temporal_hidden_dim,
+                num_layers=temporal_transformer_num_layers,
+                nhead=temporal_transformer_nhead,
+                ff_dim=temporal_transformer_ff_dim,
+                dropout=temporal_transformer_dropout,
+                pos_dim=temporal_pos_dim,
+                use_pos_embed=temporal_use_pos_embed,
+            ).to(DEVICE)
+        else:
+            temporal_model = TemporalGRUGenerator(
+                input_dim=len(cont_cols),
+                hidden_dim=temporal_hidden_dim,
+                num_layers=temporal_num_layers,
+                dropout=temporal_dropout,
+            ).to(DEVICE)
         temporal_path = BASE_DIR / "results" / "temporal.pt"
         if not temporal_path.exists():
             raise SystemExit(f"missing temporal model file: {temporal_path}")
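
Note: in both sampling scripts the checkpoint load itself sits outside the hunks
shown; the assumed continuation is the usual state-dict pattern below. The
practical constraint is that "temporal_backbone" must agree between the training
run that wrote temporal.pt and the sampling run, or load_state_dict will fail on
mismatched keys:

    # Assumed continuation (not visible in this patch).
    state = torch.load(temporal_path, map_location=DEVICE)
    temporal_model.load_state_dict(state)
    temporal_model.eval()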
"temporal_hidden_dim": 256, "temporal_num_layers": 1, "temporal_dropout": 0.0, + "temporal_pos_dim": 64, + "temporal_use_pos_embed": True, + "temporal_transformer_num_layers": 2, + "temporal_transformer_nhead": 4, + "temporal_transformer_ff_dim": 512, + "temporal_transformer_dropout": 0.1, "temporal_epochs": 2, "temporal_lr": 1e-3, "quantile_loss_weight": 0.0, @@ -226,12 +234,25 @@ def main(): temporal_model = None opt_temporal = None if bool(config.get("use_temporal_stage1", False)): - temporal_model = TemporalGRUGenerator( - input_dim=len(model_cont_cols), - hidden_dim=int(config.get("temporal_hidden_dim", 256)), - num_layers=int(config.get("temporal_num_layers", 1)), - dropout=float(config.get("temporal_dropout", 0.0)), - ).to(device) + temporal_backbone = str(config.get("temporal_backbone", "gru")) + if temporal_backbone == "transformer": + temporal_model = TemporalTransformerGenerator( + input_dim=len(model_cont_cols), + hidden_dim=int(config.get("temporal_hidden_dim", 256)), + num_layers=int(config.get("temporal_transformer_num_layers", 2)), + nhead=int(config.get("temporal_transformer_nhead", 4)), + ff_dim=int(config.get("temporal_transformer_ff_dim", 512)), + dropout=float(config.get("temporal_transformer_dropout", 0.1)), + pos_dim=int(config.get("temporal_pos_dim", 64)), + use_pos_embed=bool(config.get("temporal_use_pos_embed", True)), + ).to(device) + else: + temporal_model = TemporalGRUGenerator( + input_dim=len(model_cont_cols), + hidden_dim=int(config.get("temporal_hidden_dim", 256)), + num_layers=int(config.get("temporal_num_layers", 1)), + dropout=float(config.get("temporal_dropout", 0.0)), + ).to(device) opt_temporal = torch.optim.Adam( temporal_model.parameters(), lr=float(config.get("temporal_lr", config["lr"])),