优化6个类,现在ks降低到0.28,史称3.0版本
This commit is contained in:
@@ -123,6 +123,7 @@ class HybridDiffusionModel(nn.Module):
|
||||
transformer_nhead: int = 8,
|
||||
transformer_ff_dim: int = 2048,
|
||||
transformer_dropout: float = 0.1,
|
||||
cond_cont_dim: int = 0,
|
||||
cond_vocab_size: int = 0,
|
||||
cond_dim: int = 32,
|
||||
use_tanh_eps: bool = False,
|
||||
@@ -144,6 +145,7 @@ class HybridDiffusionModel(nn.Module):
|
||||
self.cond_embed = None
|
||||
if cond_vocab_size and cond_vocab_size > 0:
|
||||
self.cond_embed = nn.Embedding(cond_vocab_size, cond_dim)
|
||||
self.cond_cont_dim = cond_cont_dim
|
||||
|
||||
self.disc_embeds = nn.ModuleList([
|
||||
nn.Embedding(vocab_size + 1, min(32, vocab_size * 2))
|
||||
@@ -154,6 +156,8 @@ class HybridDiffusionModel(nn.Module):
|
||||
self.cont_proj = nn.Linear(cont_dim, cont_dim)
|
||||
pos_dim = pos_dim if use_pos_embed else 0
|
||||
in_dim = cont_dim + disc_embed_dim + time_dim + pos_dim + (cond_dim if self.cond_embed is not None else 0)
|
||||
if self.cond_cont_dim and self.cond_cont_dim > 0:
|
||||
in_dim += self.cond_cont_dim
|
||||
self.in_proj = nn.Linear(in_dim, hidden_dim)
|
||||
if backbone_type == "transformer":
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
@@ -186,7 +190,14 @@ class HybridDiffusionModel(nn.Module):
|
||||
for vocab_size in disc_vocab_sizes
|
||||
])
|
||||
|
||||
def forward(self, x_cont: torch.Tensor, x_disc: torch.Tensor, t: torch.Tensor, cond: torch.Tensor = None):
|
||||
def forward(
|
||||
self,
|
||||
x_cont: torch.Tensor,
|
||||
x_disc: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
cond: torch.Tensor = None,
|
||||
cond_cont: torch.Tensor = None,
|
||||
):
|
||||
"""x_cont: (B,T,Cc), x_disc: (B,T,Cd) with integer tokens."""
|
||||
time_emb = self.time_embed(t)
|
||||
time_emb = time_emb.unsqueeze(1).expand(-1, x_cont.size(1), -1)
|
||||
@@ -211,6 +222,10 @@ class HybridDiffusionModel(nn.Module):
|
||||
parts.append(pos_emb.unsqueeze(0).expand(x_cont.size(0), -1, -1))
|
||||
if cond_feat is not None:
|
||||
parts.append(cond_feat)
|
||||
if self.cond_cont_dim and self.cond_cont_dim > 0:
|
||||
if cond_cont is None:
|
||||
raise ValueError("cond_cont is required when cond_cont_dim > 0")
|
||||
parts.append(cond_cont)
|
||||
feat = torch.cat(parts, dim=-1)
|
||||
feat = self.in_proj(feat)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user