#!/usr/bin/env python3
"""Hybrid diffusion scaffold for continuous + discrete HAI features.

Continuous: Gaussian diffusion (DDPM-style).
Discrete: mask-based diffusion (predict original token).
"""

import math
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
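

# Expected data layout (inferred from the forward() docstring below, stated
# here for orientation): x_cont is (B, T, Cc) float features, x_disc is
# (B, T, Cd) integer tokens, with one vocabulary per discrete feature column.
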
def cosine_beta_schedule(timesteps: int, s: float = 0.008) -> torch.Tensor:
    """Cosine noise schedule (Nichol & Dhariwal, 2021), returned as per-step betas."""
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 1e-5, 0.999)
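

# Note (a sketch, not in the original file): q_sample_continuous below consumes
# the cumulative product alpha_bar rather than the betas themselves; it can be
# recovered with
#
#     betas = cosine_beta_schedule(T)
#     alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
#
# so alpha_bar_t decays smoothly from ~1 (no noise) toward 0 (pure noise).
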
def q_sample_continuous(x0: torch.Tensor, t: torch.Tensor, alphas_cumprod: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Add Gaussian noise to continuous features at timestep t."""
    noise = torch.randn_like(x0)
    a_bar = alphas_cumprod[t].view(-1, 1, 1)
    xt = torch.sqrt(a_bar) * x0 + torch.sqrt(1.0 - a_bar) * noise
    return xt, noise
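

# Note (standard DDPM identity, not in the original file): given a predicted
# noise eps_hat, the forward process above inverts in closed form,
#
#     x0_hat = (xt - torch.sqrt(1.0 - a_bar) * eps_hat) / torch.sqrt(a_bar)
#
# which is the x0 estimate a sampler would use at each denoising step.
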
def q_sample_discrete(
    x0: torch.Tensor,
    t: torch.Tensor,
    mask_tokens: torch.Tensor,
    max_t: int,
    mask_scale: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Randomly mask discrete tokens with a cosine schedule over t."""
    bsz = x0.size(0)
    # cosine mask probability: p(0) = 0, p(max_t) = 1
    p = 0.5 * (1.0 - torch.cos(math.pi * t.float() / float(max_t)))
    if mask_scale != 1.0:
        p = torch.clamp(p * mask_scale, 0.0, 1.0)
    p = p.view(bsz, 1, 1)
    mask = torch.rand_like(x0.float()) < p
    # each discrete feature column gets its own mask token id
    x_masked = x0.clone()
    for i in range(x0.size(2)):
        x_masked[:, :, i][mask[:, :, i]] = mask_tokens[i]
    return x_masked, mask
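

# Usage sketch (hypothetical sizes): the embeddings in HybridDiffusionModel
# allocate vocab_size + 1 rows per feature, which leaves index vocab_size free
# to serve as the per-feature mask id, e.g. with vocab sizes [5, 9]:
#
#     mask_tokens = torch.tensor([5, 9])
#     x_masked, mask = q_sample_discrete(x_disc, t, mask_tokens, max_t=1000)
#
# The training target at masked positions is the original x_disc token.
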
class SinusoidalTimeEmbedding(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        half = self.dim // 2
        freqs = torch.exp(-math.log(10000) * torch.arange(0, half, device=t.device) / half)
        args = t.float().unsqueeze(1) * freqs.unsqueeze(0)
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
        if self.dim % 2 == 1:
            emb = F.pad(emb, (0, 1))
        return emb
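

# Shape check (sketch): SinusoidalTimeEmbedding(64)(torch.tensor([0, 10, 999]))
# yields a (3, 64) tensor; for odd dims the concatenated sin/cos halves come up
# one column short, so forward() zero-pads the last column.
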
class HybridDiffusionModel(nn.Module):
    def __init__(
        self,
        cont_dim: int,
        disc_vocab_sizes: List[int],
        time_dim: int = 64,
        hidden_dim: int = 256,
        num_layers: int = 1,
        dropout: float = 0.0,
        ff_mult: int = 2,
        pos_dim: int = 64,
        use_pos_embed: bool = True,
        cond_vocab_size: int = 0,
        cond_dim: int = 32,
        use_tanh_eps: bool = False,
        eps_scale: float = 1.0,
    ):
        super().__init__()
        self.cont_dim = cont_dim
        self.disc_vocab_sizes = disc_vocab_sizes

        self.time_embed = SinusoidalTimeEmbedding(time_dim)
        self.use_tanh_eps = use_tanh_eps
        self.eps_scale = eps_scale
        self.pos_dim = pos_dim
        self.use_pos_embed = use_pos_embed

        self.cond_vocab_size = cond_vocab_size
        self.cond_dim = cond_dim
        self.cond_embed = None
        if cond_vocab_size and cond_vocab_size > 0:
            self.cond_embed = nn.Embedding(cond_vocab_size, cond_dim)

        # one embedding table per discrete feature; the extra row (+ 1) leaves
        # index vocab_size available for the mask token used by q_sample_discrete
        self.disc_embeds = nn.ModuleList([
            nn.Embedding(vocab_size + 1, min(32, vocab_size * 2))
            for vocab_size in disc_vocab_sizes
        ])
        disc_embed_dim = sum(e.embedding_dim for e in self.disc_embeds)

        self.cont_proj = nn.Linear(cont_dim, cont_dim)
        pos_dim = pos_dim if use_pos_embed else 0
        in_dim = cont_dim + disc_embed_dim + time_dim + pos_dim + (cond_dim if self.cond_embed is not None else 0)
        self.in_proj = nn.Linear(in_dim, hidden_dim)
        self.backbone = nn.GRU(
            hidden_dim,
            hidden_dim,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0.0,
            batch_first=True,
        )
        self.post_norm = nn.LayerNorm(hidden_dim)
        self.post_ff = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * ff_mult),
            nn.GELU(),
            nn.Linear(hidden_dim * ff_mult, hidden_dim),
        )

        # heads: noise prediction for the continuous path, per-feature logits
        # over the original vocabulary for the discrete path
        self.cont_head = nn.Linear(hidden_dim, cont_dim)
        self.disc_heads = nn.ModuleList([
            nn.Linear(hidden_dim, vocab_size)
            for vocab_size in disc_vocab_sizes
        ])
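
    # Construction sketch (hypothetical sizes): four continuous channels, two
    # discrete features, and an optional categorical condition:
    #
    #     model = HybridDiffusionModel(cont_dim=4, disc_vocab_sizes=[5, 9],
    #                                  cond_vocab_size=3)
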
    def forward(self, x_cont: torch.Tensor, x_disc: torch.Tensor, t: torch.Tensor, cond: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """x_cont: (B, T, Cc) floats; x_disc: (B, T, Cd) integer tokens."""
        time_emb = self.time_embed(t)
        time_emb = time_emb.unsqueeze(1).expand(-1, x_cont.size(1), -1)
        pos_emb = None
        if self.use_pos_embed and self.pos_dim > 0:
            pos_emb = self._positional_encoding(x_cont.size(1), self.pos_dim, x_cont.device)

        disc_embs = []
        for i, emb in enumerate(self.disc_embeds):
            disc_embs.append(emb(x_disc[:, :, i]))
        disc_feat = torch.cat(disc_embs, dim=-1)

        cond_feat = None
        if self.cond_embed is not None:
            if cond is None:
                raise ValueError("cond is required when cond_vocab_size > 0")
            cond_feat = self.cond_embed(cond).unsqueeze(1).expand(-1, x_cont.size(1), -1)

        cont_feat = self.cont_proj(x_cont)
        parts = [cont_feat, disc_feat, time_emb]
        if pos_emb is not None:
            parts.append(pos_emb.unsqueeze(0).expand(x_cont.size(0), -1, -1))
        if cond_feat is not None:
            parts.append(cond_feat)
        feat = torch.cat(parts, dim=-1)
        feat = self.in_proj(feat)

        out, _ = self.backbone(feat)
        out = self.post_norm(out)
        out = out + self.post_ff(out)

        eps_pred = self.cont_head(out)
        if self.use_tanh_eps:
            eps_pred = torch.tanh(eps_pred) * self.eps_scale
        logits = [head(out) for head in self.disc_heads]
        return eps_pred, logits

    @staticmethod
    def _positional_encoding(seq_len: int, dim: int, device: torch.device) -> torch.Tensor:
        pos = torch.arange(seq_len, device=device).float().unsqueeze(1)
        div = torch.exp(torch.arange(0, dim, 2, device=device).float() * (-math.log(10000.0) / dim))
        pe = torch.zeros(seq_len, dim, device=device)
        pe[:, 0::2] = torch.sin(pos * div)
        # slice to dim // 2 columns so odd dims don't shape-mismatch on assignment
        pe[:, 1::2] = torch.cos(pos * div)[:, : dim // 2]
        return pe
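

# Minimal smoke test (a sketch, not part of the original scaffold): builds the
# model with hypothetical sizes, runs both noising paths and one forward pass,
# and computes the two training losses.
if __name__ == "__main__":
    torch.manual_seed(0)
    T = 1000
    betas = cosine_beta_schedule(T)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    bsz, seq_len, cont_dim = 8, 16, 4
    vocab_sizes = [5, 9]
    x_cont = torch.randn(bsz, seq_len, cont_dim)
    x_disc = torch.stack(
        [torch.randint(0, v, (bsz, seq_len)) for v in vocab_sizes], dim=-1
    )
    t = torch.randint(0, T, (bsz,))

    xt, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
    mask_tokens = torch.tensor(vocab_sizes)  # index vocab_size as the mask id
    x_masked, mask = q_sample_discrete(x_disc, t, mask_tokens, max_t=T)

    model = HybridDiffusionModel(cont_dim=cont_dim, disc_vocab_sizes=vocab_sizes)
    eps_pred, logits = model(xt, x_masked, t)

    # eps-prediction MSE for the continuous path; cross-entropy against the
    # original tokens for the discrete path (over all positions for brevity;
    # a masked-only variant would gate both terms by `mask`)
    loss_cont = F.mse_loss(eps_pred, noise)
    loss_disc = sum(
        F.cross_entropy(
            logits[i].reshape(-1, vocab_sizes[i]),
            x_disc[:, :, i].reshape(-1),
        )
        for i in range(len(vocab_sizes))
    )
    print(f"eps loss: {loss_cont.item():.4f}  token loss: {loss_disc.item():.4f}")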