Use transformer for temporal trend model

@@ -57,9 +57,16 @@
  "type6_features": ["P4_HT_PO", "P2_24Vdc", "P2_HILout"],
  "shuffle_buffer": 256,
  "use_temporal_stage1": true,
  "temporal_backbone": "transformer",
  "temporal_hidden_dim": 384,
  "temporal_num_layers": 1,
  "temporal_dropout": 0.0,
  "temporal_pos_dim": 64,
  "temporal_use_pos_embed": true,
  "temporal_transformer_num_layers": 2,
  "temporal_transformer_nhead": 4,
  "temporal_transformer_ff_dim": 512,
  "temporal_transformer_dropout": 0.1,
  "temporal_epochs": 3,
  "temporal_lr": 0.001,
  "quantile_loss_weight": 0.2,

@@ -51,9 +51,16 @@
  "full_stats": true,
  "shuffle_buffer": 1024,
  "use_temporal_stage1": true,
  "temporal_backbone": "transformer",
  "temporal_hidden_dim": 512,
  "temporal_num_layers": 2,
  "temporal_dropout": 0.0,
  "temporal_pos_dim": 64,
  "temporal_use_pos_embed": true,
  "temporal_transformer_num_layers": 2,
  "temporal_transformer_nhead": 4,
  "temporal_transformer_ff_dim": 512,
  "temporal_transformer_dropout": 0.1,
  "temporal_epochs": 5,
  "temporal_lr": 0.0005,
  "sample_batch_size": 4,
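
Both configs enable the stage-1 temporal model and select the transformer backbone; the second uses a larger hidden size, more layers and epochs, and a smaller learning rate. A minimal sketch of how such a JSON block can be consumed, assuming a plain json.load and the same cfg.get(key, default) pattern the scripts below use; the file name "sample_config.json" and the helper name are placeholders, not part of this commit:

import json

def load_temporal_cfg(path: str = "sample_config.json") -> dict:
    # Read the raw JSON config.
    with open(path, "r", encoding="utf-8") as fh:
        cfg = json.load(fh)
    # Coerce the transformer-related keys with the same defaults the scripts use,
    # so an older config without these keys still resolves to a working GRU setup.
    return {
        "use_temporal_stage1": bool(cfg.get("use_temporal_stage1", False)),
        "temporal_backbone": str(cfg.get("temporal_backbone", "gru")),
        "temporal_hidden_dim": int(cfg.get("temporal_hidden_dim", 256)),
        "temporal_transformer_num_layers": int(cfg.get("temporal_transformer_num_layers", 2)),
        "temporal_transformer_nhead": int(cfg.get("temporal_transformer_nhead", 4)),
        "temporal_transformer_ff_dim": int(cfg.get("temporal_transformer_ff_dim", 512)),
        "temporal_transformer_dropout": float(cfg.get("temporal_transformer_dropout", 0.1)),
        "temporal_pos_dim": int(cfg.get("temporal_pos_dim", 64)),
        "temporal_use_pos_embed": bool(cfg.get("temporal_use_pos_embed", True)),
    }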

@@ -13,7 +13,7 @@ import torch
import torch.nn.functional as F

from data_utils import load_split, inverse_quantile_transform, quantile_calibrate_to_real
-from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, cosine_beta_schedule
+from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, TemporalTransformerGenerator, cosine_beta_schedule
from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path

@@ -154,9 +154,16 @@ def main():
    type5_cols = [c for c in type5_cols if c in cont_cols]
    model_cont_cols = [c for c in cont_cols if c not in type1_cols and c not in type5_cols]
    use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
    temporal_backbone = str(cfg.get("temporal_backbone", "gru"))
    temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
    temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
    temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
    temporal_pos_dim = int(cfg.get("temporal_pos_dim", 64))
    temporal_use_pos_embed = bool(cfg.get("temporal_use_pos_embed", True))
    temporal_transformer_num_layers = int(cfg.get("temporal_transformer_num_layers", 2))
    temporal_transformer_nhead = int(cfg.get("temporal_transformer_nhead", 4))
    temporal_transformer_ff_dim = int(cfg.get("temporal_transformer_ff_dim", 512))
    temporal_transformer_dropout = float(cfg.get("temporal_transformer_dropout", 0.1))
    backbone_type = str(cfg.get("backbone_type", "gru"))
    transformer_num_layers = int(cfg.get("transformer_num_layers", 2))
    transformer_nhead = int(cfg.get("transformer_nhead", 4))

@@ -193,6 +200,18 @@ def main():

    temporal_model = None
    if use_temporal_stage1:
        if temporal_backbone == "transformer":
            temporal_model = TemporalTransformerGenerator(
                input_dim=len(model_cont_cols),
                hidden_dim=temporal_hidden_dim,
                num_layers=temporal_transformer_num_layers,
                nhead=temporal_transformer_nhead,
                ff_dim=temporal_transformer_ff_dim,
                dropout=temporal_transformer_dropout,
                pos_dim=temporal_pos_dim,
                use_pos_embed=temporal_use_pos_embed,
            ).to(device)
        else:
            temporal_model = TemporalGRUGenerator(
                input_dim=len(model_cont_cols),
                hidden_dim=temporal_hidden_dim,

@@ -106,6 +106,79 @@ class TemporalGRUGenerator(nn.Module):
        return torch.cat(outputs, dim=1)


class TemporalTransformerGenerator(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int = 256,
        num_layers: int = 2,
        nhead: int = 4,
        ff_dim: int = 512,
        dropout: float = 0.1,
        pos_dim: int = 64,
        use_pos_embed: bool = True,
    ):
        super().__init__()
        self.start_token = nn.Parameter(torch.zeros(input_dim))
        self.in_proj = nn.Linear(input_dim, hidden_dim)
        self.pos_dim = pos_dim
        self.use_pos_embed = use_pos_embed
        self.pos_proj = nn.Linear(pos_dim, hidden_dim) if use_pos_embed and pos_dim > 0 else None
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dim_feedforward=ff_dim,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
        )
        self.backbone = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.out = nn.Linear(hidden_dim, input_dim)

    def forward_teacher(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Teacher-forced pass: predict step t+1 from steps <= t under a causal mask.
        if x.size(1) < 2:
            raise ValueError("sequence length must be >= 2 for teacher forcing")
        inp = x[:, :-1, :]
        out = self._encode(inp)
        pred_next = self.out(out)
        # The trend keeps the real first step and fills the rest with one-step-ahead predictions.
        trend = torch.zeros_like(x)
        trend[:, 0, :] = x[:, 0, :]
        trend[:, 1:, :] = pred_next
        return trend, pred_next

    def generate(self, batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor:
        # Autoregressive rollout seeded by the learned start token.
        context = self.start_token.unsqueeze(0).unsqueeze(1).expand(batch_size, 1, -1).to(device)
        outputs = []
        for _ in range(seq_len):
            out = self._encode(context)
            next_token = self.out(out[:, -1, :])
            outputs.append(next_token.unsqueeze(1))
            context = torch.cat([context, next_token.unsqueeze(1)], dim=1)
        return torch.cat(outputs, dim=1)

    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        feat = self.in_proj(x)
        if self.pos_proj is not None and self.use_pos_embed and self.pos_dim > 0:
            pos = self._positional_encoding(x.size(1), self.pos_dim, x.device)
            pos = self.pos_proj(pos).unsqueeze(0).expand(x.size(0), -1, -1)
            feat = feat + pos
        mask = self._causal_mask(x.size(1), x.device)
        return self.backbone(feat, mask=mask)

    @staticmethod
    def _causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
        # True above the diagonal = position is masked (cannot attend to the future).
        return torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()

    @staticmethod
    def _positional_encoding(seq_len: int, dim: int, device: torch.device) -> torch.Tensor:
        pos = torch.arange(seq_len, device=device).float().unsqueeze(1)
        div = torch.exp(torch.arange(0, dim, 2, device=device).float() * (-math.log(10000.0) / dim))
        pe = torch.zeros(seq_len, dim, device=device)
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        return pe


class HybridDiffusionModel(nn.Module):
    def __init__(
        self,
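
A minimal usage sketch for the new generator, assuming hybrid_diffusion.py is importable; the tensor shapes, window length, and the plain MSE next-step loss are placeholders for illustration, not values or objectives taken from the repository:

import torch
from hybrid_diffusion import TemporalTransformerGenerator

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TemporalTransformerGenerator(input_dim=8, hidden_dim=128, num_layers=2, nhead=4, ff_dim=256).to(device)

# Teacher forcing: predict step t+1 from steps <= t under the causal mask.
x = torch.randn(16, 32, 8, device=device)                    # (batch, seq_len, features)
trend, pred_next = model.forward_teacher(x)
loss = torch.nn.functional.mse_loss(pred_next, x[:, 1:, :])  # next-step regression target
loss.backward()

# Free-running generation: roll the model out from the learned start token.
with torch.no_grad():
    synthetic_trend = model.generate(batch_size=4, seq_len=32, device=device)
print(trend.shape, synthetic_trend.shape)  # torch.Size([16, 32, 8]) torch.Size([4, 32, 8])

Unlike the GRU backbone, the transformer re-encodes the whole growing context at every generation step, so free-running rollout is quadratic in sequence length; teacher-forced training remains a single parallel pass.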

@@ -10,7 +10,7 @@ import torch
import torch.nn.functional as F

from data_utils import load_split
-from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, cosine_beta_schedule
+from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, TemporalTransformerGenerator, cosine_beta_schedule
from platform_utils import resolve_device, safe_path, ensure_dir

BASE_DIR = Path(__file__).resolve().parent

@@ -48,9 +48,16 @@ def main():
    use_tanh_eps = bool(cfg.get("use_tanh_eps", False))
    eps_scale = float(cfg.get("eps_scale", 1.0))
    use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
    temporal_backbone = str(cfg.get("temporal_backbone", "gru"))
    temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
    temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
    temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
    temporal_pos_dim = int(cfg.get("temporal_pos_dim", 64))
    temporal_use_pos_embed = bool(cfg.get("temporal_use_pos_embed", True))
    temporal_transformer_num_layers = int(cfg.get("temporal_transformer_num_layers", 2))
    temporal_transformer_nhead = int(cfg.get("temporal_transformer_nhead", 4))
    temporal_transformer_ff_dim = int(cfg.get("temporal_transformer_ff_dim", 512))
    temporal_transformer_dropout = float(cfg.get("temporal_transformer_dropout", 0.1))
    cont_target = str(cfg.get("cont_target", "eps"))
    cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
    model_time_dim = int(cfg.get("model_time_dim", 64))

@@ -108,6 +115,18 @@ def main():

    temporal_model = None
    if use_temporal_stage1:
        if temporal_backbone == "transformer":
            temporal_model = TemporalTransformerGenerator(
                input_dim=len(cont_cols),
                hidden_dim=temporal_hidden_dim,
                num_layers=temporal_transformer_num_layers,
                nhead=temporal_transformer_nhead,
                ff_dim=temporal_transformer_ff_dim,
                dropout=temporal_transformer_dropout,
                pos_dim=temporal_pos_dim,
                use_pos_embed=temporal_use_pos_embed,
            ).to(DEVICE)
        else:
            temporal_model = TemporalGRUGenerator(
                input_dim=len(cont_cols),
                hidden_dim=temporal_hidden_dim,
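
Downstream of this block the sampling script presumably restores the trained stage-1 weights and rolls the generator out before the diffusion stage runs; that part is outside this hunk, so the sketch below only covers the checkpoint restore and the rollout. The checkpoint name "temporal_model.pt", the batch size, and the sequence length are placeholder assumptions:

import torch

def load_and_generate_trend(temporal_model, ckpt_path="temporal_model.pt",
                            batch_size=4, seq_len=256, device=torch.device("cpu")):
    # Restore the stage-1 generator trained separately from the diffusion model.
    state = torch.load(ckpt_path, map_location=device)
    temporal_model.load_state_dict(state)
    temporal_model.eval()
    # Autoregressive rollout: one feature vector per time step, seeded by the start token.
    with torch.no_grad():
        return temporal_model.generate(batch_size=batch_size, seq_len=seq_len, device=device)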

@@ -15,6 +15,7 @@ from data_utils import load_split, windowed_batches
from hybrid_diffusion import (
    HybridDiffusionModel,
    TemporalGRUGenerator,
    TemporalTransformerGenerator,
    cosine_beta_schedule,
    q_sample_continuous,
    q_sample_discrete,

@@ -66,9 +67,16 @@ DEFAULTS = {
    "cont_target": "eps",  # eps | x0
    "cont_clamp_x0": 0.0,
    "use_temporal_stage1": True,
    "temporal_backbone": "gru",
    "temporal_hidden_dim": 256,
    "temporal_num_layers": 1,
    "temporal_dropout": 0.0,
    "temporal_pos_dim": 64,
    "temporal_use_pos_embed": True,
    "temporal_transformer_num_layers": 2,
    "temporal_transformer_nhead": 4,
    "temporal_transformer_ff_dim": 512,
    "temporal_transformer_dropout": 0.1,
    "temporal_epochs": 2,
    "temporal_lr": 1e-3,
    "quantile_loss_weight": 0.0,

@@ -226,6 +234,19 @@ def main():
    temporal_model = None
    opt_temporal = None
    if bool(config.get("use_temporal_stage1", False)):
        temporal_backbone = str(config.get("temporal_backbone", "gru"))
        if temporal_backbone == "transformer":
            temporal_model = TemporalTransformerGenerator(
                input_dim=len(model_cont_cols),
                hidden_dim=int(config.get("temporal_hidden_dim", 256)),
                num_layers=int(config.get("temporal_transformer_num_layers", 2)),
                nhead=int(config.get("temporal_transformer_nhead", 4)),
                ff_dim=int(config.get("temporal_transformer_ff_dim", 512)),
                dropout=float(config.get("temporal_transformer_dropout", 0.1)),
                pos_dim=int(config.get("temporal_pos_dim", 64)),
                use_pos_embed=bool(config.get("temporal_use_pos_embed", True)),
            ).to(device)
        else:
            temporal_model = TemporalGRUGenerator(
                input_dim=len(model_cont_cols),
                hidden_dim=int(config.get("temporal_hidden_dim", 256)),
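
The optimizer construction (opt_temporal) and the stage-1 training loop fall outside this hunk. A hedged sketch of how they might look, reusing config keys that do appear in this commit (temporal_lr, temporal_epochs) but assuming a plain next-step MSE objective, a generic iterable of windowed batches, and that both backbones expose the same forward_teacher interface; all of these are placeholders for whatever the training script actually does:

import torch
import torch.nn.functional as F

def train_temporal_stage1(temporal_model, batches, config, device):
    # Adam on the stage-1 generator only; the diffusion model is untouched here.
    opt_temporal = torch.optim.Adam(temporal_model.parameters(), lr=float(config.get("temporal_lr", 1e-3)))
    temporal_model.train()
    for _ in range(int(config.get("temporal_epochs", 2))):
        for x in batches:  # x: (batch, seq_len, n_cont_features), already normalized
            x = x.to(device)
            _, pred_next = temporal_model.forward_teacher(x)
            loss = F.mse_loss(pred_next, x[:, 1:, :])  # predict each step from its past
            opt_temporal.zero_grad()
            loss.backward()
            opt_temporal.step()
    return temporal_model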