Use transformer for temporal trend model
This commit is contained in:
@@ -106,6 +106,79 @@ class TemporalGRUGenerator(nn.Module):
|
||||
return torch.cat(outputs, dim=1)
|
||||
|
||||
|
||||
class TemporalTransformerGenerator(nn.Module):
    """Autoregressive transformer that models the temporal trend of a sequence.

    A causal (decoder-style) transformer encoder predicts the next timestep
    from all preceding ones.  Training uses teacher forcing via
    ``forward_teacher``; free-running sampling uses ``generate``, seeded by a
    learned start token.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int = 256,
        num_layers: int = 2,
        nhead: int = 4,
        ff_dim: int = 512,
        dropout: float = 0.1,
        pos_dim: int = 64,
        use_pos_embed: bool = True,
    ):
        """Build the model.

        Args:
            input_dim: Dimensionality of each timestep vector.
            hidden_dim: Transformer model dimension (``d_model``).
            num_layers: Number of stacked encoder layers.
            nhead: Number of attention heads (must divide ``hidden_dim``).
            ff_dim: Feed-forward inner dimension of each layer.
            dropout: Dropout rate inside the transformer layers.
            pos_dim: Dimensionality of the sinusoidal positional features.
            use_pos_embed: If False, no positional information is added.
        """
        super().__init__()
        # Learned start-of-sequence token that seeds free-running generation.
        self.start_token = nn.Parameter(torch.zeros(input_dim))
        self.in_proj = nn.Linear(input_dim, hidden_dim)
        self.pos_dim = pos_dim
        self.use_pos_embed = use_pos_embed
        # Projects sinusoidal positional features into the model dimension;
        # None disables positional information entirely.
        self.pos_proj = nn.Linear(pos_dim, hidden_dim) if use_pos_embed and pos_dim > 0 else None
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dim_feedforward=ff_dim,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
        )
        self.backbone = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.out = nn.Linear(hidden_dim, input_dim)

    def forward_teacher(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Teacher-forced one-step-ahead prediction.

        Args:
            x: Ground-truth sequence of shape ``(batch, seq_len, input_dim)``
               with ``seq_len >= 2``.

        Returns:
            trend: ``(batch, seq_len, input_dim)`` — position 0 copies the
                input's first step, positions ``1..seq_len-1`` hold the
                predictions for those steps.
            pred_next: ``(batch, seq_len - 1, input_dim)`` raw predictions.

        Raises:
            ValueError: If the sequence is shorter than 2 timesteps.
        """
        if x.size(1) < 2:
            raise ValueError("sequence length must be >= 2 for teacher forcing")
        inp = x[:, :-1, :]
        out = self._encode(inp)
        pred_next = self.out(out)
        # Assemble the full-length trend with a single concatenation instead of
        # in-place index writes: step 0 is ground truth, the rest is predicted.
        trend = torch.cat([x[:, :1, :], pred_next], dim=1)
        return trend, pred_next

    def generate(self, batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor:
        """Autoregressively sample ``seq_len`` steps from the start token.

        Args:
            batch_size: Number of independent sequences to sample.
            seq_len: Number of timesteps to generate (0 yields an empty sequence).
            device: Device on which to run generation.

        Returns:
            Tensor of shape ``(batch_size, seq_len, input_dim)``.
        """
        if seq_len <= 0:
            # Original code crashed on torch.cat([]) here; return an empty
            # sequence instead.
            return torch.empty(batch_size, 0, self.start_token.numel(), device=device)
        context = self.start_token.unsqueeze(0).unsqueeze(1).expand(batch_size, 1, -1).to(device)
        outputs = []
        for _ in range(seq_len):
            # Re-encode the full context each step: O(seq_len^2) overall, but
            # simple and correct under the causal mask.
            out = self._encode(context)
            next_token = self.out(out[:, -1, :])
            outputs.append(next_token.unsqueeze(1))
            context = torch.cat([context, next_token.unsqueeze(1)], dim=1)
        return torch.cat(outputs, dim=1)

    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        """Project input, add positional features, and run the causal backbone."""
        feat = self.in_proj(x)
        if self.pos_proj is not None and self.use_pos_embed and self.pos_dim > 0:
            pos = self._positional_encoding(x.size(1), self.pos_dim, x.device)
            pos = self.pos_proj(pos).unsqueeze(0).expand(x.size(0), -1, -1)
            feat = feat + pos
        mask = self._causal_mask(x.size(1), x.device)
        return self.backbone(feat, mask=mask)

    @staticmethod
    def _causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
        """Boolean mask where True above the diagonal blocks future positions."""
        return torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()

    @staticmethod
    def _positional_encoding(seq_len: int, dim: int, device: torch.device) -> torch.Tensor:
        """Standard sinusoidal positional encoding of shape ``(seq_len, dim)``."""
        pos = torch.arange(seq_len, device=device).float().unsqueeze(1)
        div = torch.exp(torch.arange(0, dim, 2, device=device).float() * (-math.log(10000.0) / dim))
        pe = torch.zeros(seq_len, dim, device=device)
        pe[:, 0::2] = torch.sin(pos * div)
        # For odd `dim` the cosine channels have one fewer column than `div`
        # produces; slice so the assignment never shape-mismatches.
        pe[:, 1::2] = torch.cos(pos * div)[:, : dim // 2]
        return pe
|
||||
|
||||
|
||||
class HybridDiffusionModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user