#!/usr/bin/env python3
"""Hybrid diffusion scaffold for continuous + discrete HAI features.

Continuous: Gaussian diffusion (DDPM-style).
Discrete: mask-based diffusion (predict original token).
"""

import math
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def cosine_beta_schedule(timesteps: int, s: float = 0.008) -> torch.Tensor:
    """Return a cosine-shaped beta schedule with ``timesteps`` entries.

    The cumulative alpha product is a squared cosine evaluated on
    ``timesteps + 1`` evenly spaced points, normalised so its first value
    is 1. Betas are derived from consecutive ratios and clipped to
    ``[1e-5, 0.999]``.

    Args:
        timesteps: Number of diffusion steps (length of the result).
        s: Offset added to the normalised step before the cosine.

    Returns:
        1-D tensor of betas of length ``timesteps``.
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 1e-5, 0.999)


def q_sample_continuous(
    x0: torch.Tensor,
    t: torch.Tensor,
    alphas_cumprod: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Add Gaussian noise to continuous features at timestep t.

    Args:
        x0: Clean features with the batch on the first axis, e.g. (B, T, C).
        t: Integer timestep indices into ``alphas_cumprod``, one per sample.
        alphas_cumprod: 1-D table of cumulative alpha products.

    Returns:
        ``(xt, noise)`` where
        ``xt = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * noise`` and ``noise`` is
        the standard-normal sample that was mixed in.
    """
    noise = torch.randn_like(x0)
    # One singleton axis per non-batch dimension of x0, so the per-sample
    # coefficient broadcasts correctly for any rank (a fixed (-1, 1, 1)
    # would silently mis-broadcast 2-D inputs).
    trailing = [1] * (x0.dim() - 1)
    a_bar = alphas_cumprod[t].view(-1, *trailing)
    xt = torch.sqrt(a_bar) * x0 + torch.sqrt(1.0 - a_bar) * noise
    return xt, noise


def q_sample_discrete(
    x0: torch.Tensor,
    t: torch.Tensor,
    mask_tokens: torch.Tensor,
    max_t: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Randomly mask discrete tokens with a linear schedule over t.

    Each entry of sample ``b`` is replaced independently with probability
    ``t[b] / max_t``.

    Args:
        x0: Integer tokens of shape (B, T, Cd).
        t: Timesteps, one per batch element.
        mask_tokens: Mask token id for each of the Cd features; entries
            beyond the first Cd are ignored.
        max_t: Timestep at which the masking probability reaches 1.

    Returns:
        ``(x_masked, mask)``: a new tensor with masked positions replaced by
        the feature's mask token (``x0`` is not modified), and the boolean
        mask of replaced positions.

    Raises:
        IndexError: If ``mask_tokens`` has fewer entries than ``x0`` has
            features.
    """
    bsz = x0.size(0)
    num_features = x0.size(2)
    p = (t.float() / float(max_t)).view(bsz, 1, 1)
    mask = torch.rand_like(x0.float()) < p

    fill = torch.as_tensor(mask_tokens, dtype=x0.dtype, device=x0.device).reshape(-1)
    if fill.numel() < num_features:
        raise IndexError(
            f"mask_tokens has {fill.numel()} entries but x0 has "
            f"{num_features} discrete features"
        )
    # Broadcast the per-feature mask ids over batch and time in one shot.
    fill = fill[:num_features].view(1, 1, -1)
    x_masked = torch.where(mask, fill, x0)
    return x_masked, mask


class SinusoidalTimeEmbedding(nn.Module):
    """Map scalar timesteps to fixed sinusoidal embeddings of size ``dim``."""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed timesteps ``t`` of shape (B,) into a (B, dim) tensor.

        The first ``dim // 2`` columns are sines and the next ``dim // 2``
        are cosines; an odd ``dim`` gets one trailing zero column.
        """
        half = self.dim // 2
        freqs = torch.exp(
            -math.log(10000) * torch.arange(0, half, device=t.device) / half
        )
        args = t.float().unsqueeze(1) * freqs.unsqueeze(0)
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
        if self.dim % 2 == 1:
            # sin/cos only fill 2 * half columns; pad to the requested width.
            emb = F.pad(emb, (0, 1))
        return emb


class HybridDiffusionModel(nn.Module):
    """GRU denoiser over joint continuous and discrete sequences.

    Continuous features are linearly projected, every discrete feature is
    embedded with its own table, and both are concatenated with a time
    embedding before the GRU backbone. One linear head predicts the
    continuous output and one head per discrete feature predicts token
    logits.
    """

    def __init__(
        self,
        cont_dim: int,
        disc_vocab_sizes: List[int],
        time_dim: int = 64,
        hidden_dim: int = 256,
    ):
        """Build the model.

        Args:
            cont_dim: Number of continuous features.
            disc_vocab_sizes: Vocabulary size of each discrete feature; may
                be empty for a continuous-only model.
            time_dim: Width of the sinusoidal time embedding.
            hidden_dim: Hidden size of the GRU backbone.
        """
        super().__init__()
        self.cont_dim = cont_dim
        self.disc_vocab_sizes = disc_vocab_sizes

        self.time_embed = SinusoidalTimeEmbedding(time_dim)

        # Each table has one row more than the vocabulary -- presumably the
        # slot for the mask token id used by the discrete corruption.
        self.disc_embeds = nn.ModuleList([
            nn.Embedding(vocab_size + 1, min(32, vocab_size * 2))
            for vocab_size in disc_vocab_sizes
        ])
        disc_embed_dim = sum(e.embedding_dim for e in self.disc_embeds)

        self.cont_proj = nn.Linear(cont_dim, cont_dim)
        self.in_proj = nn.Linear(cont_dim + disc_embed_dim + time_dim, hidden_dim)
        self.backbone = nn.GRU(hidden_dim, hidden_dim, batch_first=True)

        self.cont_head = nn.Linear(hidden_dim, cont_dim)
        self.disc_heads = nn.ModuleList([
            nn.Linear(hidden_dim, vocab_size)
            for vocab_size in disc_vocab_sizes
        ])

    def forward(self, x_cont: torch.Tensor, x_disc: torch.Tensor, t: torch.Tensor):
        """x_cont: (B,T,Cc), x_disc: (B,T,Cd) with integer tokens.

        Args:
            x_cont: Continuous features.
            x_disc: Discrete tokens; not read when the model has no discrete
                features.
            t: Timesteps of shape (B,).

        Returns:
            ``(eps_pred, logits)``: the continuous head output of shape
            (B, T, Cc) and a list with one (B, T, vocab_size) logits tensor
            per discrete feature (empty if there are none).
        """
        time_emb = self.time_embed(t)
        time_emb = time_emb.unsqueeze(1).expand(-1, x_cont.size(1), -1)

        # Gather all inputs in one list so a model without discrete features
        # never calls torch.cat on an empty sequence.
        parts = [self.cont_proj(x_cont)]
        parts.extend(
            emb(x_disc[:, :, i]) for i, emb in enumerate(self.disc_embeds)
        )
        parts.append(time_emb)

        feat = self.in_proj(torch.cat(parts, dim=-1))
        out, _ = self.backbone(feat)

        eps_pred = self.cont_head(out)
        logits = [head(out) for head in self.disc_heads]
        return eps_pred, logits