Use transformer for temporal trend model

This commit is contained in:
Mingzhe Yang
2026-02-04 02:40:57 +08:00
parent 84ac4cd2eb
commit 175fc684e3
6 changed files with 166 additions and 20 deletions

View File

@@ -106,6 +106,79 @@ class TemporalGRUGenerator(nn.Module):
return torch.cat(outputs, dim=1)
class TemporalTransformerGenerator(nn.Module):
def __init__(
self,
input_dim: int,
hidden_dim: int = 256,
num_layers: int = 2,
nhead: int = 4,
ff_dim: int = 512,
dropout: float = 0.1,
pos_dim: int = 64,
use_pos_embed: bool = True,
):
super().__init__()
self.start_token = nn.Parameter(torch.zeros(input_dim))
self.in_proj = nn.Linear(input_dim, hidden_dim)
self.pos_dim = pos_dim
self.use_pos_embed = use_pos_embed
self.pos_proj = nn.Linear(pos_dim, hidden_dim) if use_pos_embed and pos_dim > 0 else None
encoder_layer = nn.TransformerEncoderLayer(
d_model=hidden_dim,
nhead=nhead,
dim_feedforward=ff_dim,
dropout=dropout,
batch_first=True,
activation="gelu",
)
self.backbone = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
self.out = nn.Linear(hidden_dim, input_dim)
def forward_teacher(self, x: torch.Tensor) -> torch.Tensor:
if x.size(1) < 2:
raise ValueError("sequence length must be >= 2 for teacher forcing")
inp = x[:, :-1, :]
out = self._encode(inp)
pred_next = self.out(out)
trend = torch.zeros_like(x)
trend[:, 0, :] = x[:, 0, :]
trend[:, 1:, :] = pred_next
return trend, pred_next
def generate(self, batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor:
context = self.start_token.unsqueeze(0).unsqueeze(1).expand(batch_size, 1, -1).to(device)
outputs = []
for _ in range(seq_len):
out = self._encode(context)
next_token = self.out(out[:, -1, :])
outputs.append(next_token.unsqueeze(1))
context = torch.cat([context, next_token.unsqueeze(1)], dim=1)
return torch.cat(outputs, dim=1)
def _encode(self, x: torch.Tensor) -> torch.Tensor:
feat = self.in_proj(x)
if self.pos_proj is not None and self.use_pos_embed and self.pos_dim > 0:
pos = self._positional_encoding(x.size(1), self.pos_dim, x.device)
pos = self.pos_proj(pos).unsqueeze(0).expand(x.size(0), -1, -1)
feat = feat + pos
mask = self._causal_mask(x.size(1), x.device)
return self.backbone(feat, mask=mask)
@staticmethod
def _causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
return torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()
@staticmethod
def _positional_encoding(seq_len: int, dim: int, device: torch.device) -> torch.Tensor:
pos = torch.arange(seq_len, device=device).float().unsqueeze(1)
div = torch.exp(torch.arange(0, dim, 2, device=device).float() * (-math.log(10000.0) / dim))
pe = torch.zeros(seq_len, dim, device=device)
pe[:, 0::2] = torch.sin(pos * div)
pe[:, 1::2] = torch.cos(pos * div)
return pe
class HybridDiffusionModel(nn.Module):
def __init__(
self,