update
This commit is contained in:
@@ -69,11 +69,13 @@ class SinusoidalTimeEmbedding(nn.Module):
|
||||
class TemporalGRUGenerator(nn.Module):
|
||||
"""Stage-1 temporal generator for sequence trend."""
|
||||
|
||||
def __init__(self, input_dim: int, hidden_dim: int = 256, num_layers: int = 1, dropout: float = 0.0):
|
||||
def __init__(self, input_dim: int, hidden_dim: int = 256, num_layers: int = 1, dropout: float = 0.0, cond_dim: int = 0):
|
||||
super().__init__()
|
||||
self.input_dim = int(input_dim)
|
||||
self.cond_dim = int(cond_dim)
|
||||
self.start_token = nn.Parameter(torch.zeros(input_dim))
|
||||
self.gru = nn.GRU(
|
||||
input_dim,
|
||||
input_dim + self.cond_dim,
|
||||
hidden_dim,
|
||||
num_layers=num_layers,
|
||||
dropout=dropout if num_layers > 1 else 0.0,
|
||||
@@ -81,11 +83,16 @@ class TemporalGRUGenerator(nn.Module):
|
||||
)
|
||||
self.out = nn.Linear(hidden_dim, input_dim)
|
||||
|
||||
def forward_teacher(self, x: torch.Tensor) -> torch.Tensor:
|
||||
def forward_teacher(self, x: torch.Tensor, cond_cont: torch.Tensor | None = None) -> torch.Tensor:
|
||||
"""Teacher-forced next-step prediction. Returns trend sequence and preds."""
|
||||
if x.size(1) < 2:
|
||||
raise ValueError("sequence length must be >= 2 for teacher forcing")
|
||||
inp = x[:, :-1, :]
|
||||
if self.cond_dim > 0:
|
||||
if cond_cont is None:
|
||||
cond_cont = torch.zeros(x.size(0), x.size(1), self.cond_dim, device=x.device, dtype=x.dtype)
|
||||
inp = torch.cat([x[:, :-1, :], cond_cont[:, :-1, :]], dim=-1)
|
||||
else:
|
||||
inp = x[:, :-1, :]
|
||||
out, _ = self.gru(inp)
|
||||
pred_next = self.out(out)
|
||||
trend = torch.zeros_like(x)
|
||||
@@ -93,13 +100,35 @@ class TemporalGRUGenerator(nn.Module):
|
||||
trend[:, 1:, :] = pred_next
|
||||
return trend, pred_next
|
||||
|
||||
def generate(self, batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor:
|
||||
def generate(
|
||||
self,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
device: torch.device,
|
||||
cond_cont: torch.Tensor | None = None,
|
||||
start_x: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Autoregressively generate a backbone sequence."""
|
||||
h = None
|
||||
prev = self.start_token.unsqueeze(0).expand(batch_size, -1).to(device)
|
||||
if start_x is not None:
|
||||
if start_x.dim() == 3 and start_x.size(1) == 1:
|
||||
start_x = start_x[:, 0, :]
|
||||
prev = start_x.to(device)
|
||||
else:
|
||||
prev = self.start_token.unsqueeze(0).expand(batch_size, -1).to(device)
|
||||
if self.cond_dim > 0:
|
||||
if cond_cont is None:
|
||||
cond_cont = torch.zeros(batch_size, seq_len, self.cond_dim, device=device, dtype=prev.dtype)
|
||||
else:
|
||||
cond_cont = cond_cont.to(device)
|
||||
outputs = []
|
||||
for _ in range(seq_len):
|
||||
out, h = self.gru(prev.unsqueeze(1), h)
|
||||
for t in range(seq_len):
|
||||
if self.cond_dim > 0:
|
||||
ct = cond_cont[:, t, :] if t < cond_cont.size(1) else torch.zeros(batch_size, self.cond_dim, device=device, dtype=prev.dtype)
|
||||
step_inp = torch.cat([prev, ct], dim=-1)
|
||||
else:
|
||||
step_inp = prev
|
||||
out, h = self.gru(step_inp.unsqueeze(1), h)
|
||||
nxt = self.out(out.squeeze(1))
|
||||
outputs.append(nxt.unsqueeze(1))
|
||||
prev = nxt
|
||||
@@ -117,10 +146,13 @@ class TemporalTransformerGenerator(nn.Module):
|
||||
dropout: float = 0.1,
|
||||
pos_dim: int = 64,
|
||||
use_pos_embed: bool = True,
|
||||
cond_dim: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
self.input_dim = int(input_dim)
|
||||
self.cond_dim = int(cond_dim)
|
||||
self.start_token = nn.Parameter(torch.zeros(input_dim))
|
||||
self.in_proj = nn.Linear(input_dim, hidden_dim)
|
||||
self.in_proj = nn.Linear(input_dim + self.cond_dim, hidden_dim)
|
||||
self.pos_dim = pos_dim
|
||||
self.use_pos_embed = use_pos_embed
|
||||
self.pos_proj = nn.Linear(pos_dim, hidden_dim) if use_pos_embed and pos_dim > 0 else None
|
||||
@@ -135,10 +167,15 @@ class TemporalTransformerGenerator(nn.Module):
|
||||
self.backbone = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
|
||||
self.out = nn.Linear(hidden_dim, input_dim)
|
||||
|
||||
def forward_teacher(self, x: torch.Tensor) -> torch.Tensor:
|
||||
def forward_teacher(self, x: torch.Tensor, cond_cont: torch.Tensor | None = None) -> torch.Tensor:
|
||||
if x.size(1) < 2:
|
||||
raise ValueError("sequence length must be >= 2 for teacher forcing")
|
||||
inp = x[:, :-1, :]
|
||||
if self.cond_dim > 0:
|
||||
if cond_cont is None:
|
||||
cond_cont = torch.zeros(x.size(0), x.size(1), self.cond_dim, device=x.device, dtype=x.dtype)
|
||||
inp = torch.cat([x[:, :-1, :], cond_cont[:, :-1, :]], dim=-1)
|
||||
else:
|
||||
inp = x[:, :-1, :]
|
||||
out = self._encode(inp)
|
||||
pred_next = self.out(out)
|
||||
trend = torch.zeros_like(x)
|
||||
@@ -146,14 +183,40 @@ class TemporalTransformerGenerator(nn.Module):
|
||||
trend[:, 1:, :] = pred_next
|
||||
return trend, pred_next
|
||||
|
||||
def generate(self, batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor:
|
||||
context = self.start_token.unsqueeze(0).unsqueeze(1).expand(batch_size, 1, -1).to(device)
|
||||
def generate(
|
||||
self,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
device: torch.device,
|
||||
cond_cont: torch.Tensor | None = None,
|
||||
start_x: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if start_x is not None:
|
||||
if start_x.dim() == 2:
|
||||
context_x = start_x.unsqueeze(1).to(device)
|
||||
else:
|
||||
context_x = start_x.to(device)
|
||||
else:
|
||||
context_x = self.start_token.unsqueeze(0).unsqueeze(1).expand(batch_size, 1, -1).to(device)
|
||||
if self.cond_dim > 0:
|
||||
if cond_cont is None:
|
||||
cond_cont = torch.zeros(batch_size, seq_len, self.cond_dim, device=device, dtype=context_x.dtype)
|
||||
else:
|
||||
cond_cont = cond_cont.to(device)
|
||||
if cond_cont.size(1) < seq_len:
|
||||
pad = torch.zeros(batch_size, seq_len - cond_cont.size(1), self.cond_dim, device=device, dtype=cond_cont.dtype)
|
||||
cond_cont = torch.cat([cond_cont, pad], dim=1)
|
||||
outputs = []
|
||||
for _ in range(seq_len):
|
||||
out = self._encode(context)
|
||||
if self.cond_dim > 0:
|
||||
cond_ctx = cond_cont[:, : context_x.size(1), :]
|
||||
context_in = torch.cat([context_x, cond_ctx], dim=-1)
|
||||
else:
|
||||
context_in = context_x
|
||||
out = self._encode(context_in)
|
||||
next_token = self.out(out[:, -1, :])
|
||||
outputs.append(next_token.unsqueeze(1))
|
||||
context = torch.cat([context, next_token.unsqueeze(1)], dim=1)
|
||||
context_x = torch.cat([context_x, next_token.unsqueeze(1)], dim=1)
|
||||
return torch.cat(outputs, dim=1)
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
Reference in New Issue
Block a user