update新结构
This commit is contained in:
@@ -84,6 +84,46 @@ class FeatureGraphMixer(nn.Module):
|
||||
return x + mixed
|
||||
|
||||
|
||||
class TemporalGRUGenerator(nn.Module):
|
||||
"""Stage-1 temporal generator (autoregressive GRU) for sequence backbone."""
|
||||
|
||||
def __init__(self, input_dim: int, hidden_dim: int = 256, num_layers: int = 1, dropout: float = 0.0):
|
||||
super().__init__()
|
||||
self.start_token = nn.Parameter(torch.zeros(input_dim))
|
||||
self.gru = nn.GRU(
|
||||
input_dim,
|
||||
hidden_dim,
|
||||
num_layers=num_layers,
|
||||
dropout=dropout if num_layers > 1 else 0.0,
|
||||
batch_first=True,
|
||||
)
|
||||
self.out = nn.Linear(hidden_dim, input_dim)
|
||||
|
||||
def forward_teacher(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Teacher-forced next-step prediction. Returns trend sequence and preds."""
|
||||
if x.size(1) < 2:
|
||||
raise ValueError("sequence length must be >= 2 for teacher forcing")
|
||||
inp = x[:, :-1, :]
|
||||
out, _ = self.gru(inp)
|
||||
pred_next = self.out(out)
|
||||
trend = torch.zeros_like(x)
|
||||
trend[:, 0, :] = x[:, 0, :]
|
||||
trend[:, 1:, :] = pred_next
|
||||
return trend, pred_next
|
||||
|
||||
def generate(self, batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor:
|
||||
"""Autoregressively generate a backbone sequence."""
|
||||
h = None
|
||||
prev = self.start_token.unsqueeze(0).expand(batch_size, -1).to(device)
|
||||
outputs = []
|
||||
for _ in range(seq_len):
|
||||
out, h = self.gru(prev.unsqueeze(1), h)
|
||||
nxt = self.out(out.squeeze(1))
|
||||
outputs.append(nxt.unsqueeze(1))
|
||||
prev = nxt
|
||||
return torch.cat(outputs, dim=1)
|
||||
|
||||
|
||||
class HybridDiffusionModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user