transformer
This commit is contained in:
@@ -118,6 +118,11 @@ class HybridDiffusionModel(nn.Module):
|
||||
ff_mult: int = 2,
|
||||
pos_dim: int = 64,
|
||||
use_pos_embed: bool = True,
|
||||
backbone_type: str = "gru", # gru | transformer
|
||||
transformer_num_layers: int = 4,
|
||||
transformer_nhead: int = 8,
|
||||
transformer_ff_dim: int = 2048,
|
||||
transformer_dropout: float = 0.1,
|
||||
cond_vocab_size: int = 0,
|
||||
cond_dim: int = 32,
|
||||
use_tanh_eps: bool = False,
|
||||
@@ -132,6 +137,7 @@ class HybridDiffusionModel(nn.Module):
|
||||
self.eps_scale = eps_scale
|
||||
self.pos_dim = pos_dim
|
||||
self.use_pos_embed = use_pos_embed
|
||||
self.backbone_type = backbone_type
|
||||
|
||||
self.cond_vocab_size = cond_vocab_size
|
||||
self.cond_dim = cond_dim
|
||||
@@ -149,13 +155,24 @@ class HybridDiffusionModel(nn.Module):
|
||||
pos_dim = pos_dim if use_pos_embed else 0
|
||||
in_dim = cont_dim + disc_embed_dim + time_dim + pos_dim + (cond_dim if self.cond_embed is not None else 0)
|
||||
self.in_proj = nn.Linear(in_dim, hidden_dim)
|
||||
self.backbone = nn.GRU(
|
||||
hidden_dim,
|
||||
hidden_dim,
|
||||
num_layers=num_layers,
|
||||
dropout=dropout if num_layers > 1 else 0.0,
|
||||
batch_first=True,
|
||||
)
|
||||
if backbone_type == "transformer":
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=hidden_dim,
|
||||
nhead=transformer_nhead,
|
||||
dim_feedforward=transformer_ff_dim,
|
||||
dropout=transformer_dropout,
|
||||
batch_first=True,
|
||||
activation="gelu",
|
||||
)
|
||||
self.backbone = nn.TransformerEncoder(encoder_layer, num_layers=transformer_num_layers)
|
||||
else:
|
||||
self.backbone = nn.GRU(
|
||||
hidden_dim,
|
||||
hidden_dim,
|
||||
num_layers=num_layers,
|
||||
dropout=dropout if num_layers > 1 else 0.0,
|
||||
batch_first=True,
|
||||
)
|
||||
self.post_norm = nn.LayerNorm(hidden_dim)
|
||||
self.post_ff = nn.Sequential(
|
||||
nn.Linear(hidden_dim, hidden_dim * ff_mult),
|
||||
@@ -197,7 +214,10 @@ class HybridDiffusionModel(nn.Module):
|
||||
feat = torch.cat(parts, dim=-1)
|
||||
feat = self.in_proj(feat)
|
||||
|
||||
out, _ = self.backbone(feat)
|
||||
if self.backbone_type == "transformer":
|
||||
out = self.backbone(feat)
|
||||
else:
|
||||
out, _ = self.backbone(feat)
|
||||
out = self.post_norm(out)
|
||||
out = out + self.post_ff(out)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user