transformer

2026-01-27 00:41:42 +08:00
parent 65391910a2
commit 334db7082b
12 changed files with 175 additions and 11 deletions

View File

@@ -32,6 +32,11 @@
"model_ff_mult": 2,
"model_pos_dim": 64,
"model_use_pos_embed": true,
"backbone_type": "transformer",
"transformer_num_layers": 2,
"transformer_nhead": 4,
"transformer_ff_dim": 512,
"transformer_dropout": 0.1,
"disc_mask_scale": 0.9,
"cont_loss_weighting": "inv_std",
"cont_loss_eps": 1e-6,

View File

@@ -32,6 +32,11 @@
"model_ff_mult": 2,
"model_pos_dim": 64,
"model_use_pos_embed": true,
"backbone_type": "transformer",
"transformer_num_layers": 2,
"transformer_nhead": 4,
"transformer_ff_dim": 512,
"transformer_dropout": 0.1,
"disc_mask_scale": 0.9,
"cont_loss_weighting": "inv_std",
"cont_loss_eps": 1e-6,

View File

@@ -32,6 +32,11 @@
"model_ff_mult": 2,
"model_pos_dim": 64,
"model_use_pos_embed": true,
"backbone_type": "transformer",
"transformer_num_layers": 2,
"transformer_nhead": 4,
"transformer_ff_dim": 512,
"transformer_dropout": 0.1,
"disc_mask_scale": 0.9,
"cont_loss_weighting": "inv_std",
"cont_loss_eps": 1e-6,

View File

@@ -144,6 +144,11 @@ def main():
temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
+backbone_type = str(cfg.get("backbone_type", "gru"))
+transformer_num_layers = int(cfg.get("transformer_num_layers", 2))
+transformer_nhead = int(cfg.get("transformer_nhead", 4))
+transformer_ff_dim = int(cfg.get("transformer_ff_dim", 512))
+transformer_dropout = float(cfg.get("transformer_dropout", 0.1))
model = HybridDiffusionModel(
cont_dim=len(cont_cols),
@@ -155,6 +160,11 @@ def main():
ff_mult=int(cfg.get("model_ff_mult", 2)),
pos_dim=int(cfg.get("model_pos_dim", 64)),
use_pos_embed=bool(cfg.get("model_use_pos_embed", True)),
+backbone_type=backbone_type,
+transformer_num_layers=transformer_num_layers,
+transformer_nhead=transformer_nhead,
+transformer_ff_dim=transformer_ff_dim,
+transformer_dropout=transformer_dropout,
cond_vocab_size=cond_vocab_size if use_condition else 0,
cond_dim=int(cfg.get("cond_dim", 32)),
use_tanh_eps=bool(cfg.get("use_tanh_eps", False)),
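Because every new option is read through cfg.get with a fallback, configs written before this commit keep working and stay on the GRU path. A small illustrative sketch of that behaviour (the config dict here is invented for the example):

# A config from before this change has none of the new keys, so each
# cfg.get(...) returns its fallback and the GRU backbone is selected.
cfg = {"model_ff_mult": 2, "model_pos_dim": 64}  # illustrative old config

backbone_type = str(cfg.get("backbone_type", "gru"))
transformer_nhead = int(cfg.get("transformer_nhead", 4))

assert backbone_type == "gru"  # old runs are unaffected by the new branch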

View File

@@ -118,6 +118,11 @@ class HybridDiffusionModel(nn.Module):
ff_mult: int = 2,
pos_dim: int = 64,
use_pos_embed: bool = True,
+backbone_type: str = "gru",  # gru | transformer
+transformer_num_layers: int = 4,
+transformer_nhead: int = 8,
+transformer_ff_dim: int = 2048,
+transformer_dropout: float = 0.1,
cond_vocab_size: int = 0,
cond_dim: int = 32,
use_tanh_eps: bool = False,
@@ -132,6 +137,7 @@ class HybridDiffusionModel(nn.Module):
self.eps_scale = eps_scale
self.pos_dim = pos_dim
self.use_pos_embed = use_pos_embed
+self.backbone_type = backbone_type
self.cond_vocab_size = cond_vocab_size
self.cond_dim = cond_dim
@@ -149,13 +155,24 @@ class HybridDiffusionModel(nn.Module):
        pos_dim = pos_dim if use_pos_embed else 0
        in_dim = cont_dim + disc_embed_dim + time_dim + pos_dim + (cond_dim if self.cond_embed is not None else 0)
        self.in_proj = nn.Linear(in_dim, hidden_dim)
-        self.backbone = nn.GRU(
-            hidden_dim,
-            hidden_dim,
-            num_layers=num_layers,
-            dropout=dropout if num_layers > 1 else 0.0,
-            batch_first=True,
-        )
+        if backbone_type == "transformer":
+            encoder_layer = nn.TransformerEncoderLayer(
+                d_model=hidden_dim,
+                nhead=transformer_nhead,
+                dim_feedforward=transformer_ff_dim,
+                dropout=transformer_dropout,
+                batch_first=True,
+                activation="gelu",
+            )
+            self.backbone = nn.TransformerEncoder(encoder_layer, num_layers=transformer_num_layers)
+        else:
+            self.backbone = nn.GRU(
+                hidden_dim,
+                hidden_dim,
+                num_layers=num_layers,
+                dropout=dropout if num_layers > 1 else 0.0,
+                batch_first=True,
+            )
        self.post_norm = nn.LayerNorm(hidden_dim)
        self.post_ff = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * ff_mult),
@@ -197,7 +214,10 @@ class HybridDiffusionModel(nn.Module):
        feat = torch.cat(parts, dim=-1)
        feat = self.in_proj(feat)
-        out, _ = self.backbone(feat)
+        if self.backbone_type == "transformer":
+            out = self.backbone(feat)
+        else:
+            out, _ = self.backbone(feat)
        out = self.post_norm(out)
        out = out + self.post_ff(out)
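The only call-site change forward() needs is the return-type difference between the two backbones: nn.GRU returns an (output, h_n) tuple while nn.TransformerEncoder returns a single tensor. A standalone sketch of the switch above (not the full HybridDiffusionModel), runnable as-is:

import torch
import torch.nn as nn

def build_backbone(backbone_type: str, hidden_dim: int = 256) -> nn.Module:
    # Mirrors the branch in __init__ above; the sizes here are the
    # training-config values (2 layers, 4 heads, ff_dim 512).
    if backbone_type == "transformer":
        layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=4, dim_feedforward=512,
            dropout=0.1, batch_first=True, activation="gelu",
        )
        return nn.TransformerEncoder(layer, num_layers=2)
    return nn.GRU(hidden_dim, hidden_dim, num_layers=1, batch_first=True)

feat = torch.randn(8, 32, 256)  # (batch, seq_len, hidden_dim)
for kind in ("gru", "transformer"):
    backbone = build_backbone(kind)
    # GRU -> (output, h_n) tuple; TransformerEncoder -> tensor.
    out = backbone(feat) if kind == "transformer" else backbone(feat)[0]
    assert out.shape == feat.shape  # both map (B, T, H) -> (B, T, H)

One semantic difference worth knowing: no attention mask is passed, so the encoder attends bidirectionally over the whole window, whereas the GRU processes it left to right. For a denoiser that sees the entire noisy sequence at once, that is presumably intended.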

View File

@@ -60,6 +60,11 @@ def main():
model_ff_mult = int(cfg.get("model_ff_mult", 2))
model_pos_dim = int(cfg.get("model_pos_dim", 64))
model_use_pos = bool(cfg.get("model_use_pos_embed", True))
+backbone_type = str(cfg.get("backbone_type", "gru"))
+transformer_num_layers = int(cfg.get("transformer_num_layers", 2))
+transformer_nhead = int(cfg.get("transformer_nhead", 4))
+transformer_ff_dim = int(cfg.get("transformer_ff_dim", 512))
+transformer_dropout = float(cfg.get("transformer_dropout", 0.1))
split = load_split(str(SPLIT_PATH))
time_col = split.get("time_column", "time")
@@ -87,6 +92,11 @@ def main():
ff_mult=model_ff_mult,
pos_dim=model_pos_dim,
use_pos_embed=model_use_pos,
+backbone_type=backbone_type,
+transformer_num_layers=transformer_num_layers,
+transformer_nhead=transformer_nhead,
+transformer_ff_dim=transformer_ff_dim,
+transformer_dropout=transformer_dropout,
cond_vocab_size=cond_vocab_size,
cond_dim=cond_dim,
use_tanh_eps=use_tanh_eps,

View File

@@ -200,6 +200,11 @@ def main():
ff_mult=int(config.get("model_ff_mult", 2)),
pos_dim=int(config.get("model_pos_dim", 64)),
use_pos_embed=bool(config.get("model_use_pos_embed", True)),
+backbone_type=str(config.get("backbone_type", "gru")),
+transformer_num_layers=int(config.get("transformer_num_layers", 4)),
+transformer_nhead=int(config.get("transformer_nhead", 8)),
+transformer_ff_dim=int(config.get("transformer_ff_dim", 2048)),
+transformer_dropout=float(config.get("transformer_dropout", 0.1)),
cond_vocab_size=cond_vocab_size,
cond_dim=int(config.get("cond_dim", 32)),
use_tanh_eps=bool(config.get("use_tanh_eps", False)),
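Note that the fallbacks in this last file (4 layers, 8 heads, ff_dim 2048, matching the constructor defaults) differ from the 2/4/512 fallbacks used in the training and evaluation scripts above. A config that sets backbone_type to transformer but omits the size keys would therefore build differently shaped models in different scripts, and loading a shared checkpoint would fail on mismatched weights. A hedged sketch of one way to keep the fallbacks in a single place; the helper name and its values (taken from the training script) are suggestions, not part of the commit:

from typing import Any, Dict

def transformer_kwargs(cfg: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical shared helper: one set of fallbacks for every script.
    return {
        "backbone_type": str(cfg.get("backbone_type", "gru")),
        "transformer_num_layers": int(cfg.get("transformer_num_layers", 2)),
        "transformer_nhead": int(cfg.get("transformer_nhead", 4)),
        "transformer_ff_dim": int(cfg.get("transformer_ff_dim", 512)),
        "transformer_dropout": float(cfg.get("transformer_dropout", 0.1)),
    }

# Usage sketch: model = HybridDiffusionModel(..., **transformer_kwargs(config))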