Use transformer for temporal trend model
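Adds a TemporalTransformerGenerator next to the existing TemporalGRUGenerator in hybrid_diffusion and selects between the two stage-1 trend backbones via the new "temporal_backbone" config key ("gru" by default, "transformer" in the updated example configs). The transformer backbone is a causally masked TransformerEncoder with optional sinusoidal positional embeddings, configured through the new keys temporal_pos_dim, temporal_use_pos_embed, temporal_transformer_num_layers, temporal_transformer_nhead, temporal_transformer_ff_dim and temporal_transformer_dropout. Both the training and the sampling entry points dispatch on temporal_backbone when constructing the stage-1 model.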
@@ -57,9 +57,16 @@
   "type6_features": ["P4_HT_PO","P2_24Vdc","P2_HILout"],
   "shuffle_buffer": 256,
   "use_temporal_stage1": true,
+  "temporal_backbone": "transformer",
   "temporal_hidden_dim": 384,
   "temporal_num_layers": 1,
   "temporal_dropout": 0.0,
+  "temporal_pos_dim": 64,
+  "temporal_use_pos_embed": true,
+  "temporal_transformer_num_layers": 2,
+  "temporal_transformer_nhead": 4,
+  "temporal_transformer_ff_dim": 512,
+  "temporal_transformer_dropout": 0.1,
   "temporal_epochs": 3,
   "temporal_lr": 0.001,
   "quantile_loss_weight": 0.2,
@@ -51,9 +51,16 @@
   "full_stats": true,
   "shuffle_buffer": 1024,
   "use_temporal_stage1": true,
+  "temporal_backbone": "transformer",
   "temporal_hidden_dim": 512,
   "temporal_num_layers": 2,
   "temporal_dropout": 0.0,
+  "temporal_pos_dim": 64,
+  "temporal_use_pos_embed": true,
+  "temporal_transformer_num_layers": 2,
+  "temporal_transformer_nhead": 4,
+  "temporal_transformer_ff_dim": 512,
+  "temporal_transformer_dropout": 0.1,
   "temporal_epochs": 5,
   "temporal_lr": 0.0005,
   "sample_batch_size": 4,
@@ -13,7 +13,7 @@ import torch
 import torch.nn.functional as F
 
 from data_utils import load_split, inverse_quantile_transform, quantile_calibrate_to_real
-from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, cosine_beta_schedule
+from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, TemporalTransformerGenerator, cosine_beta_schedule
 from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path
 
 
@@ -154,9 +154,16 @@ def main():
     type5_cols = [c for c in type5_cols if c in cont_cols]
     model_cont_cols = [c for c in cont_cols if c not in type1_cols and c not in type5_cols]
     use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
+    temporal_backbone = str(cfg.get("temporal_backbone", "gru"))
     temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
     temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
     temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
+    temporal_pos_dim = int(cfg.get("temporal_pos_dim", 64))
+    temporal_use_pos_embed = bool(cfg.get("temporal_use_pos_embed", True))
+    temporal_transformer_num_layers = int(cfg.get("temporal_transformer_num_layers", 2))
+    temporal_transformer_nhead = int(cfg.get("temporal_transformer_nhead", 4))
+    temporal_transformer_ff_dim = int(cfg.get("temporal_transformer_ff_dim", 512))
+    temporal_transformer_dropout = float(cfg.get("temporal_transformer_dropout", 0.1))
     backbone_type = str(cfg.get("backbone_type", "gru"))
     transformer_num_layers = int(cfg.get("transformer_num_layers", 2))
     transformer_nhead = int(cfg.get("transformer_nhead", 4))
@@ -193,12 +200,24 @@ def main():
 
     temporal_model = None
     if use_temporal_stage1:
-        temporal_model = TemporalGRUGenerator(
-            input_dim=len(model_cont_cols),
-            hidden_dim=temporal_hidden_dim,
-            num_layers=temporal_num_layers,
-            dropout=temporal_dropout,
-        ).to(device)
+        if temporal_backbone == "transformer":
+            temporal_model = TemporalTransformerGenerator(
+                input_dim=len(model_cont_cols),
+                hidden_dim=temporal_hidden_dim,
+                num_layers=temporal_transformer_num_layers,
+                nhead=temporal_transformer_nhead,
+                ff_dim=temporal_transformer_ff_dim,
+                dropout=temporal_transformer_dropout,
+                pos_dim=temporal_pos_dim,
+                use_pos_embed=temporal_use_pos_embed,
+            ).to(device)
+        else:
+            temporal_model = TemporalGRUGenerator(
+                input_dim=len(model_cont_cols),
+                hidden_dim=temporal_hidden_dim,
+                num_layers=temporal_num_layers,
+                dropout=temporal_dropout,
+            ).to(device)
         temporal_path = Path(args.model_path).with_name("temporal.pt")
         if not temporal_path.exists():
            raise SystemExit(f"missing temporal model file: {temporal_path}")
@@ -106,6 +106,79 @@ class TemporalGRUGenerator(nn.Module):
         return torch.cat(outputs, dim=1)
 
 
+class TemporalTransformerGenerator(nn.Module):
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dim: int = 256,
+        num_layers: int = 2,
+        nhead: int = 4,
+        ff_dim: int = 512,
+        dropout: float = 0.1,
+        pos_dim: int = 64,
+        use_pos_embed: bool = True,
+    ):
+        super().__init__()
+        self.start_token = nn.Parameter(torch.zeros(input_dim))
+        self.in_proj = nn.Linear(input_dim, hidden_dim)
+        self.pos_dim = pos_dim
+        self.use_pos_embed = use_pos_embed
+        self.pos_proj = nn.Linear(pos_dim, hidden_dim) if use_pos_embed and pos_dim > 0 else None
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=hidden_dim,
+            nhead=nhead,
+            dim_feedforward=ff_dim,
+            dropout=dropout,
+            batch_first=True,
+            activation="gelu",
+        )
+        self.backbone = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
+        self.out = nn.Linear(hidden_dim, input_dim)
+
+    def forward_teacher(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        if x.size(1) < 2:
+            raise ValueError("sequence length must be >= 2 for teacher forcing")
+        inp = x[:, :-1, :]
+        out = self._encode(inp)
+        pred_next = self.out(out)
+        trend = torch.zeros_like(x)
+        trend[:, 0, :] = x[:, 0, :]
+        trend[:, 1:, :] = pred_next
+        return trend, pred_next
+
+    def generate(self, batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor:
+        context = self.start_token.unsqueeze(0).unsqueeze(1).expand(batch_size, 1, -1).to(device)
+        outputs = []
+        for _ in range(seq_len):
+            out = self._encode(context)
+            next_token = self.out(out[:, -1, :])
+            outputs.append(next_token.unsqueeze(1))
+            context = torch.cat([context, next_token.unsqueeze(1)], dim=1)
+        return torch.cat(outputs, dim=1)
+
+    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+        feat = self.in_proj(x)
+        if self.pos_proj is not None and self.use_pos_embed and self.pos_dim > 0:
+            pos = self._positional_encoding(x.size(1), self.pos_dim, x.device)
+            pos = self.pos_proj(pos).unsqueeze(0).expand(x.size(0), -1, -1)
+            feat = feat + pos
+        mask = self._causal_mask(x.size(1), x.device)
+        return self.backbone(feat, mask=mask)
+
+    @staticmethod
+    def _causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
+        return torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()
+
+    @staticmethod
+    def _positional_encoding(seq_len: int, dim: int, device: torch.device) -> torch.Tensor:
+        pos = torch.arange(seq_len, device=device).float().unsqueeze(1)
+        div = torch.exp(torch.arange(0, dim, 2, device=device).float() * (-math.log(10000.0) / dim))
+        pe = torch.zeros(seq_len, dim, device=device)
+        pe[:, 0::2] = torch.sin(pos * div)
+        pe[:, 1::2] = torch.cos(pos * div)
+        return pe
+
+
 class HybridDiffusionModel(nn.Module):
     def __init__(
         self,
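For reference, a minimal usage sketch of the class added above; the feature count, batch size, and sequence length are illustrative values only, and the snippet assumes hybrid_diffusion.py is importable from the working directory.

import torch

from hybrid_diffusion import TemporalTransformerGenerator

model = TemporalTransformerGenerator(
    input_dim=8,        # number of continuous feature columns (illustrative)
    hidden_dim=384,
    num_layers=2,
    nhead=4,
    ff_dim=512,
    dropout=0.1,
    pos_dim=64,
    use_pos_embed=True,
)

# Teacher-forced pass: x is (batch, seq_len, input_dim); the model predicts step
# t+1 from steps <= t and returns (trend, pred_next).
x = torch.randn(4, 32, 8)
trend, pred_next = model.forward_teacher(x)
print(trend.shape, pred_next.shape)  # (4, 32, 8) and (4, 31, 8)

# Autoregressive generation from the learned start token.
with torch.no_grad():
    samples = model.generate(batch_size=4, seq_len=32, device=torch.device("cpu"))
print(samples.shape)  # (4, 32, 8)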
@@ -10,7 +10,7 @@ import torch
 import torch.nn.functional as F
 
 from data_utils import load_split
-from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, cosine_beta_schedule
+from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, TemporalTransformerGenerator, cosine_beta_schedule
 from platform_utils import resolve_device, safe_path, ensure_dir
 
 BASE_DIR = Path(__file__).resolve().parent
@@ -48,9 +48,16 @@ def main():
     use_tanh_eps = bool(cfg.get("use_tanh_eps", False))
     eps_scale = float(cfg.get("eps_scale", 1.0))
     use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
+    temporal_backbone = str(cfg.get("temporal_backbone", "gru"))
     temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
     temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
     temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
+    temporal_pos_dim = int(cfg.get("temporal_pos_dim", 64))
+    temporal_use_pos_embed = bool(cfg.get("temporal_use_pos_embed", True))
+    temporal_transformer_num_layers = int(cfg.get("temporal_transformer_num_layers", 2))
+    temporal_transformer_nhead = int(cfg.get("temporal_transformer_nhead", 4))
+    temporal_transformer_ff_dim = int(cfg.get("temporal_transformer_ff_dim", 512))
+    temporal_transformer_dropout = float(cfg.get("temporal_transformer_dropout", 0.1))
     cont_target = str(cfg.get("cont_target", "eps"))
     cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
     model_time_dim = int(cfg.get("model_time_dim", 64))
@@ -108,12 +115,24 @@ def main():
 
     temporal_model = None
     if use_temporal_stage1:
-        temporal_model = TemporalGRUGenerator(
-            input_dim=len(cont_cols),
-            hidden_dim=temporal_hidden_dim,
-            num_layers=temporal_num_layers,
-            dropout=temporal_dropout,
-        ).to(DEVICE)
+        if temporal_backbone == "transformer":
+            temporal_model = TemporalTransformerGenerator(
+                input_dim=len(cont_cols),
+                hidden_dim=temporal_hidden_dim,
+                num_layers=temporal_transformer_num_layers,
+                nhead=temporal_transformer_nhead,
+                ff_dim=temporal_transformer_ff_dim,
+                dropout=temporal_transformer_dropout,
+                pos_dim=temporal_pos_dim,
+                use_pos_embed=temporal_use_pos_embed,
+            ).to(DEVICE)
+        else:
+            temporal_model = TemporalGRUGenerator(
+                input_dim=len(cont_cols),
+                hidden_dim=temporal_hidden_dim,
+                num_layers=temporal_num_layers,
+                dropout=temporal_dropout,
+            ).to(DEVICE)
         temporal_path = BASE_DIR / "results" / "temporal.pt"
         if not temporal_path.exists():
             raise SystemExit(f"missing temporal model file: {temporal_path}")
@@ -15,6 +15,7 @@ from data_utils import load_split, windowed_batches
 from hybrid_diffusion import (
     HybridDiffusionModel,
     TemporalGRUGenerator,
+    TemporalTransformerGenerator,
     cosine_beta_schedule,
     q_sample_continuous,
     q_sample_discrete,
@@ -66,9 +67,16 @@ DEFAULTS = {
     "cont_target": "eps",  # eps | x0
     "cont_clamp_x0": 0.0,
     "use_temporal_stage1": True,
+    "temporal_backbone": "gru",
     "temporal_hidden_dim": 256,
     "temporal_num_layers": 1,
     "temporal_dropout": 0.0,
+    "temporal_pos_dim": 64,
+    "temporal_use_pos_embed": True,
+    "temporal_transformer_num_layers": 2,
+    "temporal_transformer_nhead": 4,
+    "temporal_transformer_ff_dim": 512,
+    "temporal_transformer_dropout": 0.1,
     "temporal_epochs": 2,
     "temporal_lr": 1e-3,
     "quantile_loss_weight": 0.0,
@@ -226,12 +234,25 @@ def main():
     temporal_model = None
     opt_temporal = None
     if bool(config.get("use_temporal_stage1", False)):
-        temporal_model = TemporalGRUGenerator(
-            input_dim=len(model_cont_cols),
-            hidden_dim=int(config.get("temporal_hidden_dim", 256)),
-            num_layers=int(config.get("temporal_num_layers", 1)),
-            dropout=float(config.get("temporal_dropout", 0.0)),
-        ).to(device)
+        temporal_backbone = str(config.get("temporal_backbone", "gru"))
+        if temporal_backbone == "transformer":
+            temporal_model = TemporalTransformerGenerator(
+                input_dim=len(model_cont_cols),
+                hidden_dim=int(config.get("temporal_hidden_dim", 256)),
+                num_layers=int(config.get("temporal_transformer_num_layers", 2)),
+                nhead=int(config.get("temporal_transformer_nhead", 4)),
+                ff_dim=int(config.get("temporal_transformer_ff_dim", 512)),
+                dropout=float(config.get("temporal_transformer_dropout", 0.1)),
+                pos_dim=int(config.get("temporal_pos_dim", 64)),
+                use_pos_embed=bool(config.get("temporal_use_pos_embed", True)),
+            ).to(device)
+        else:
+            temporal_model = TemporalGRUGenerator(
+                input_dim=len(model_cont_cols),
+                hidden_dim=int(config.get("temporal_hidden_dim", 256)),
+                num_layers=int(config.get("temporal_num_layers", 1)),
+                dropout=float(config.get("temporal_dropout", 0.0)),
+            ).to(device)
         opt_temporal = torch.optim.Adam(
             temporal_model.parameters(),
             lr=float(config.get("temporal_lr", config["lr"])),
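The stage-1 training step itself is unchanged by this commit and does not appear in the diff; the sketch below is only an assumption about how forward_teacher and opt_temporal could fit together with a plain next-step MSE loss, written against the interface shown above (both backbones are expected to expose the same forward_teacher signature).

import torch
import torch.nn.functional as F

# Hypothetical stage-1 update (illustration only): teacher-forced next-step MSE.
# `batch` is assumed to be a float tensor of shape (batch_size, seq_len, n_cont_features).
def temporal_step(temporal_model, opt_temporal, batch: torch.Tensor) -> float:
    opt_temporal.zero_grad()
    _, pred_next = temporal_model.forward_teacher(batch)  # (B, T-1, F) predictions
    loss = F.mse_loss(pred_next, batch[:, 1:, :])         # compare with true next steps
    loss.backward()
    opt_temporal.step()
    return float(loss.item())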