Shortcomings of continuous features in modeling temporal correlation

2026-01-23 15:06:52 +08:00
parent 0d17be9a1c
commit ff12324560
12 changed files with 1212 additions and 68 deletions


@@ -35,11 +35,14 @@ def q_sample_discrete(
     t: torch.Tensor,
     mask_tokens: torch.Tensor,
     max_t: int,
+    mask_scale: float = 1.0,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Randomly mask discrete tokens with a cosine schedule over t."""
     bsz = x0.size(0)
     # cosine schedule: p(0)=0, p(max_t)=1
     p = 0.5 * (1.0 - torch.cos(math.pi * t.float() / float(max_t)))
+    if mask_scale != 1.0:
+        p = torch.clamp(p * mask_scale, 0.0, 1.0)
     p = p.view(bsz, 1, 1)
     mask = torch.rand_like(x0.float()) < p
     x_masked = x0.clone()
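
For orientation, a minimal standalone sketch (illustrative values, not part of this commit) of how the new mask_scale parameter reshapes the cosine schedule: p(t) still rises from 0 at t=0 toward its maximum at t=max_t, and a mask_scale below 1.0 caps the per-step masking probability.

import math
import torch

max_t = 100
t = torch.tensor([0, 25, 50, 75, 100])
# Base cosine schedule, same formula as in q_sample_discrete.
p = 0.5 * (1.0 - torch.cos(math.pi * t.float() / float(max_t)))
# With mask_scale = 0.8 the schedule is scaled, then clamped into [0, 1].
p_scaled = torch.clamp(p * 0.8, 0.0, 1.0)
print(p)         # ~[0.000, 0.146, 0.500, 0.854, 1.000]
print(p_scaled)  # ~[0.000, 0.117, 0.400, 0.683, 0.800]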
@@ -70,6 +73,11 @@ class HybridDiffusionModel(nn.Module):
         disc_vocab_sizes: List[int],
         time_dim: int = 64,
         hidden_dim: int = 256,
+        num_layers: int = 1,
+        dropout: float = 0.0,
+        ff_mult: int = 2,
+        pos_dim: int = 64,
+        use_pos_embed: bool = True,
         cond_vocab_size: int = 0,
         cond_dim: int = 32,
         use_tanh_eps: bool = False,
@@ -82,6 +90,8 @@ class HybridDiffusionModel(nn.Module):
         self.time_embed = SinusoidalTimeEmbedding(time_dim)
         self.use_tanh_eps = use_tanh_eps
         self.eps_scale = eps_scale
+        self.pos_dim = pos_dim
+        self.use_pos_embed = use_pos_embed
         self.cond_vocab_size = cond_vocab_size
         self.cond_dim = cond_dim
@@ -96,9 +106,22 @@ class HybridDiffusionModel(nn.Module):
         disc_embed_dim = sum(e.embedding_dim for e in self.disc_embeds)
         self.cont_proj = nn.Linear(cont_dim, cont_dim)
-        in_dim = cont_dim + disc_embed_dim + time_dim + (cond_dim if self.cond_embed is not None else 0)
+        pos_dim = pos_dim if use_pos_embed else 0
+        in_dim = cont_dim + disc_embed_dim + time_dim + pos_dim + (cond_dim if self.cond_embed is not None else 0)
         self.in_proj = nn.Linear(in_dim, hidden_dim)
-        self.backbone = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
+        self.backbone = nn.GRU(
+            hidden_dim,
+            hidden_dim,
+            num_layers=num_layers,
+            dropout=dropout if num_layers > 1 else 0.0,
+            batch_first=True,
+        )
+        self.post_norm = nn.LayerNorm(hidden_dim)
+        self.post_ff = nn.Sequential(
+            nn.Linear(hidden_dim, hidden_dim * ff_mult),
+            nn.GELU(),
+            nn.Linear(hidden_dim * ff_mult, hidden_dim),
+        )
         self.cont_head = nn.Linear(hidden_dim, cont_dim)
         self.disc_heads = nn.ModuleList([
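
To make the new wiring concrete, here is a self-contained sketch of the GRU backbone followed by the LayerNorm and residual feed-forward post-block introduced in this hunk (the batch size, sequence length, and hyperparameter values are illustrative, not taken from the repository):

import torch
import torch.nn as nn

hidden_dim, ff_mult, num_layers, dropout = 256, 2, 2, 0.1  # example settings
backbone = nn.GRU(
    hidden_dim,
    hidden_dim,
    num_layers=num_layers,
    dropout=dropout if num_layers > 1 else 0.0,  # GRU dropout only applies between stacked layers
    batch_first=True,
)
post_norm = nn.LayerNorm(hidden_dim)
post_ff = nn.Sequential(
    nn.Linear(hidden_dim, hidden_dim * ff_mult),
    nn.GELU(),
    nn.Linear(hidden_dim * ff_mult, hidden_dim),
)

feat = torch.randn(4, 32, hidden_dim)  # (B, T, hidden_dim), as produced by in_proj
out, _ = backbone(feat)
out = post_norm(out)
out = out + post_ff(out)               # residual connection around the feed-forward block
print(out.shape)                       # torch.Size([4, 32, 256])

Because the post-block preserves the hidden width and sequence length, cont_head and disc_heads need no changes, which matches the rest of the diff.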
@@ -110,6 +133,9 @@ class HybridDiffusionModel(nn.Module):
"""x_cont: (B,T,Cc), x_disc: (B,T,Cd) with integer tokens."""
time_emb = self.time_embed(t)
time_emb = time_emb.unsqueeze(1).expand(-1, x_cont.size(1), -1)
pos_emb = None
if self.use_pos_embed and self.pos_dim > 0:
pos_emb = self._positional_encoding(x_cont.size(1), self.pos_dim, x_cont.device)
disc_embs = []
for i, emb in enumerate(self.disc_embeds):
@@ -124,15 +150,28 @@ class HybridDiffusionModel(nn.Module):
         cont_feat = self.cont_proj(x_cont)
         parts = [cont_feat, disc_feat, time_emb]
+        if pos_emb is not None:
+            parts.append(pos_emb.unsqueeze(0).expand(x_cont.size(0), -1, -1))
         if cond_feat is not None:
             parts.append(cond_feat)
         feat = torch.cat(parts, dim=-1)
         feat = self.in_proj(feat)
         out, _ = self.backbone(feat)
+        out = self.post_norm(out)
+        out = out + self.post_ff(out)
         eps_pred = self.cont_head(out)
         if self.use_tanh_eps:
             eps_pred = torch.tanh(eps_pred) * self.eps_scale
         logits = [head(out) for head in self.disc_heads]
         return eps_pred, logits
+
+    @staticmethod
+    def _positional_encoding(seq_len: int, dim: int, device: torch.device) -> torch.Tensor:
+        pos = torch.arange(seq_len, device=device).float().unsqueeze(1)
+        div = torch.exp(torch.arange(0, dim, 2, device=device).float() * (-math.log(10000.0) / dim))
+        pe = torch.zeros(seq_len, dim, device=device)
+        pe[:, 0::2] = torch.sin(pos * div)
+        pe[:, 1::2] = torch.cos(pos * div)
+        return pe
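
As a quick sanity check, the sinusoidal table built by _positional_encoding can be exercised on its own. The helper below mirrors the method body; the seq_len, dim, and batch size are illustrative, and dim is assumed even, as with the default pos_dim=64.

import math
import torch

def positional_encoding(seq_len: int, dim: int, device: torch.device) -> torch.Tensor:
    # Same construction as HybridDiffusionModel._positional_encoding:
    # even columns hold sin terms, odd columns hold cos terms.
    pos = torch.arange(seq_len, device=device).float().unsqueeze(1)
    div = torch.exp(torch.arange(0, dim, 2, device=device).float() * (-math.log(10000.0) / dim))
    pe = torch.zeros(seq_len, dim, device=device)
    pe[:, 0::2] = torch.sin(pos * div)
    pe[:, 1::2] = torch.cos(pos * div)
    return pe

pe = positional_encoding(seq_len=32, dim=64, device=torch.device("cpu"))
print(pe.shape)        # torch.Size([32, 64])
# Broadcast over the batch the same way forward() does before concatenation:
batch_pe = pe.unsqueeze(0).expand(8, -1, -1)   # (B=8, T=32, pos_dim=64)
print(batch_pe.shape)  # torch.Size([8, 32, 64])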