This commit is contained in:
2026-01-22 20:42:10 +08:00
parent f37a8ce179
commit 382c756dfe
10 changed files with 310 additions and 55 deletions

View File

@@ -36,9 +36,10 @@ def q_sample_discrete(
mask_tokens: torch.Tensor,
max_t: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Randomly mask discrete tokens with a linear schedule over t."""
"""Randomly mask discrete tokens with a cosine schedule over t."""
bsz = x0.size(0)
p = t.float() / float(max_t)
# cosine schedule: p(0)=0, p(max_t)=1
p = 0.5 * (1.0 - torch.cos(math.pi * t.float() / float(max_t)))
p = p.view(bsz, 1, 1)
mask = torch.rand_like(x0.float()) < p
x_masked = x0.clone()
@@ -69,12 +70,24 @@ class HybridDiffusionModel(nn.Module):
disc_vocab_sizes: List[int],
time_dim: int = 64,
hidden_dim: int = 256,
cond_vocab_size: int = 0,
cond_dim: int = 32,
use_tanh_eps: bool = False,
eps_scale: float = 1.0,
):
super().__init__()
self.cont_dim = cont_dim
self.disc_vocab_sizes = disc_vocab_sizes
self.time_embed = SinusoidalTimeEmbedding(time_dim)
self.use_tanh_eps = use_tanh_eps
self.eps_scale = eps_scale
self.cond_vocab_size = cond_vocab_size
self.cond_dim = cond_dim
self.cond_embed = None
if cond_vocab_size and cond_vocab_size > 0:
self.cond_embed = nn.Embedding(cond_vocab_size, cond_dim)
self.disc_embeds = nn.ModuleList([
nn.Embedding(vocab_size + 1, min(32, vocab_size * 2))
@@ -83,7 +96,8 @@ class HybridDiffusionModel(nn.Module):
disc_embed_dim = sum(e.embedding_dim for e in self.disc_embeds)
self.cont_proj = nn.Linear(cont_dim, cont_dim)
self.in_proj = nn.Linear(cont_dim + disc_embed_dim + time_dim, hidden_dim)
in_dim = cont_dim + disc_embed_dim + time_dim + (cond_dim if self.cond_embed is not None else 0)
self.in_proj = nn.Linear(in_dim, hidden_dim)
self.backbone = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
self.cont_head = nn.Linear(hidden_dim, cont_dim)
@@ -92,7 +106,7 @@ class HybridDiffusionModel(nn.Module):
for vocab_size in disc_vocab_sizes
])
def forward(self, x_cont: torch.Tensor, x_disc: torch.Tensor, t: torch.Tensor):
def forward(self, x_cont: torch.Tensor, x_disc: torch.Tensor, t: torch.Tensor, cond: torch.Tensor = None):
"""x_cont: (B,T,Cc), x_disc: (B,T,Cd) with integer tokens."""
time_emb = self.time_embed(t)
time_emb = time_emb.unsqueeze(1).expand(-1, x_cont.size(1), -1)
@@ -102,12 +116,23 @@ class HybridDiffusionModel(nn.Module):
disc_embs.append(emb(x_disc[:, :, i]))
disc_feat = torch.cat(disc_embs, dim=-1)
cond_feat = None
if self.cond_embed is not None:
if cond is None:
raise ValueError("cond is required when cond_vocab_size > 0")
cond_feat = self.cond_embed(cond).unsqueeze(1).expand(-1, x_cont.size(1), -1)
cont_feat = self.cont_proj(x_cont)
feat = torch.cat([cont_feat, disc_feat, time_emb], dim=-1)
parts = [cont_feat, disc_feat, time_emb]
if cond_feat is not None:
parts.append(cond_feat)
feat = torch.cat(parts, dim=-1)
feat = self.in_proj(feat)
out, _ = self.backbone(feat)
eps_pred = self.cont_head(out)
if self.use_tanh_eps:
eps_pred = torch.tanh(eps_pred) * self.eps_scale
logits = [head(out) for head in self.disc_heads]
return eps_pred, logits