#!/usr/bin/env python3
"""Hybrid diffusion scaffold for continuous + discrete HAI features.

Continuous: Gaussian diffusion (DDPM-style).
Discrete: mask-based diffusion (predict the original token).
"""
import math
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def cosine_beta_schedule(timesteps: int, s: float = 0.008) -> torch.Tensor:
    """Cosine cumulative-alpha schedule; betas come from consecutive alpha-bar ratios."""
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 1e-5, 0.999)


def q_sample_continuous(
    x0: torch.Tensor, t: torch.Tensor, alphas_cumprod: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Add Gaussian noise to continuous features at timestep t."""
    noise = torch.randn_like(x0)
    a_bar = alphas_cumprod[t].view(-1, 1, 1)
    # closed-form forward process: x_t = sqrt(a_bar) * x_0 + sqrt(1 - a_bar) * eps
    xt = torch.sqrt(a_bar) * x0 + torch.sqrt(1.0 - a_bar) * noise
    return xt, noise


def q_sample_discrete(
    x0: torch.Tensor,
    t: torch.Tensor,
    mask_tokens: torch.Tensor,
    max_t: int,
    mask_scale: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Randomly mask discrete tokens with a cosine schedule over t."""
    bsz = x0.size(0)
    # cosine masking schedule: p(0) = 0, p(max_t) = 1
    p = 0.5 * (1.0 - torch.cos(math.pi * t.float() / float(max_t)))
    if mask_scale != 1.0:
        p = torch.clamp(p * mask_scale, 0.0, 1.0)
    p = p.view(bsz, 1, 1)
    mask = torch.rand_like(x0.float()) < p
    x_masked = x0.clone()
    # replace masked positions in each discrete column with that column's mask token id
    for i in range(x0.size(2)):
        x_masked[:, :, i][mask[:, :, i]] = mask_tokens[i]
    return x_masked, mask

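
# --- Illustrative usage sketch (not part of the original scaffold) ---
# Shows how the two forward-noising helpers above are typically combined during
# training. All shapes and sizes here (B=4, T=16, 3 continuous channels, discrete
# vocab sizes [10, 5]) are assumptions made for the sake of the example.
def _demo_forward_noising() -> None:
    timesteps = 100
    betas = cosine_beta_schedule(timesteps)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    x_cont = torch.randn(4, 16, 3)  # (B, T, Cc) continuous features
    x_disc = torch.stack(
        [torch.randint(0, 10, (4, 16)), torch.randint(0, 5, (4, 16))], dim=-1
    )  # (B, T, Cd) integer tokens
    t = torch.randint(0, timesteps, (4,))  # one diffusion timestep per sample
    # mask token id per discrete column; here each vocabulary's extra slot is used
    mask_tokens = torch.tensor([10, 5])

    xt, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
    x_masked, mask = q_sample_discrete(x_disc, t, mask_tokens, max_t=timesteps)
    # xt/noise are (B, T, Cc); x_masked keeps x_disc's shape; mask flags masked positions
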
class SinusoidalTimeEmbedding(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        half = self.dim // 2
        freqs = torch.exp(-math.log(10000) * torch.arange(0, half, device=t.device) / half)
        args = t.float().unsqueeze(1) * freqs.unsqueeze(0)
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
        if self.dim % 2 == 1:
            emb = F.pad(emb, (0, 1))
        return emb


class TemporalGRUGenerator(nn.Module):
    """Stage-1 temporal generator for the sequence trend."""

    def __init__(self, input_dim: int, hidden_dim: int = 256, num_layers: int = 1, dropout: float = 0.0):
        super().__init__()
        self.start_token = nn.Parameter(torch.zeros(input_dim))
        self.gru = nn.GRU(
            input_dim,
            hidden_dim,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0.0,
            batch_first=True,
        )
        self.out = nn.Linear(hidden_dim, input_dim)

    def forward_teacher(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Teacher-forced next-step prediction.

        Returns the trend sequence and the raw next-step predictions.
        """
        if x.size(1) < 2:
            raise ValueError("sequence length must be >= 2 for teacher forcing")
        inp = x[:, :-1, :]
        out, _ = self.gru(inp)
        pred_next = self.out(out)
        # the trend keeps the true first step and shifts predictions into steps 1..T-1
        trend = torch.zeros_like(x)
        trend[:, 0, :] = x[:, 0, :]
        trend[:, 1:, :] = pred_next
        return trend, pred_next

    def generate(self, batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor:
        """Autoregressively generate a backbone sequence."""
        h = None
        prev = self.start_token.unsqueeze(0).expand(batch_size, -1).to(device)
        outputs = []
        for _ in range(seq_len):
            out, h = self.gru(prev.unsqueeze(1), h)
            nxt = self.out(out.squeeze(1))
            outputs.append(nxt.unsqueeze(1))
            prev = nxt
        return torch.cat(outputs, dim=1)

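
# --- Illustrative stage-1 sketch (assumption, not part of the original scaffold) ---
# One way the trend generator might be fit with a next-step MSE loss and then used
# to draw a synthetic backbone sequence. Sizes below are arbitrary placeholders.
def _demo_stage1_trend() -> None:
    gen = TemporalGRUGenerator(input_dim=3, hidden_dim=64)
    opt = torch.optim.Adam(gen.parameters(), lr=1e-3)

    x = torch.randn(4, 16, 3)  # (B, T, C) real continuous sequences
    trend, pred_next = gen.forward_teacher(x)
    loss = F.mse_loss(pred_next, x[:, 1:, :])  # predict step t+1 from steps <= t
    opt.zero_grad()
    loss.backward()
    opt.step()

    # after training, sample a synthetic trend autoregressively
    backbone = gen.generate(batch_size=4, seq_len=16, device=x.device)  # (B, T, C)
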
class HybridDiffusionModel(nn.Module):
    def __init__(
        self,
        cont_dim: int,
        disc_vocab_sizes: List[int],
        time_dim: int = 64,
        hidden_dim: int = 256,
        num_layers: int = 1,
        dropout: float = 0.0,
        ff_mult: int = 2,
        pos_dim: int = 64,
        use_pos_embed: bool = True,
        backbone_type: str = "gru",  # gru | transformer
        transformer_num_layers: int = 4,
        transformer_nhead: int = 8,
        transformer_ff_dim: int = 2048,
        transformer_dropout: float = 0.1,
        cond_vocab_size: int = 0,
        cond_dim: int = 32,
        use_tanh_eps: bool = False,
        eps_scale: float = 1.0,
    ):
        super().__init__()
        self.cont_dim = cont_dim
        self.disc_vocab_sizes = disc_vocab_sizes
        self.time_embed = SinusoidalTimeEmbedding(time_dim)
        self.use_tanh_eps = use_tanh_eps
        self.eps_scale = eps_scale
        self.pos_dim = pos_dim
        self.use_pos_embed = use_pos_embed
        self.backbone_type = backbone_type
        self.cond_vocab_size = cond_vocab_size
        self.cond_dim = cond_dim
        self.cond_embed = None
        if cond_vocab_size and cond_vocab_size > 0:
            self.cond_embed = nn.Embedding(cond_vocab_size, cond_dim)
        # one extra embedding slot per column for the mask token used by q_sample_discrete
        self.disc_embeds = nn.ModuleList([
            nn.Embedding(vocab_size + 1, min(32, vocab_size * 2))
            for vocab_size in disc_vocab_sizes
        ])
        disc_embed_dim = sum(e.embedding_dim for e in self.disc_embeds)
        self.cont_proj = nn.Linear(cont_dim, cont_dim)
        pos_dim = pos_dim if use_pos_embed else 0
        in_dim = cont_dim + disc_embed_dim + time_dim + pos_dim + (cond_dim if self.cond_embed is not None else 0)
        self.in_proj = nn.Linear(in_dim, hidden_dim)
        if backbone_type == "transformer":
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=transformer_nhead,
                dim_feedforward=transformer_ff_dim,
                dropout=transformer_dropout,
                batch_first=True,
                activation="gelu",
            )
            self.backbone = nn.TransformerEncoder(encoder_layer, num_layers=transformer_num_layers)
        else:
            self.backbone = nn.GRU(
                hidden_dim,
                hidden_dim,
                num_layers=num_layers,
                dropout=dropout if num_layers > 1 else 0.0,
                batch_first=True,
            )
        self.post_norm = nn.LayerNorm(hidden_dim)
        self.post_ff = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * ff_mult),
            nn.GELU(),
            nn.Linear(hidden_dim * ff_mult, hidden_dim),
        )
        self.cont_head = nn.Linear(hidden_dim, cont_dim)
        self.disc_heads = nn.ModuleList([
            nn.Linear(hidden_dim, vocab_size) for vocab_size in disc_vocab_sizes
        ])

    def forward(
        self,
        x_cont: torch.Tensor,
        x_disc: torch.Tensor,
        t: torch.Tensor,
        cond: Optional[torch.Tensor] = None,
    ):
        """x_cont: (B, T, Cc), x_disc: (B, T, Cd) with integer tokens."""
        time_emb = self.time_embed(t)
        time_emb = time_emb.unsqueeze(1).expand(-1, x_cont.size(1), -1)
        pos_emb = None
        if self.use_pos_embed and self.pos_dim > 0:
            pos_emb = self._positional_encoding(x_cont.size(1), self.pos_dim, x_cont.device)
        disc_embs = []
        for i, emb in enumerate(self.disc_embeds):
            disc_embs.append(emb(x_disc[:, :, i]))
        disc_feat = torch.cat(disc_embs, dim=-1)
        cond_feat = None
        if self.cond_embed is not None:
            if cond is None:
                raise ValueError("cond is required when cond_vocab_size > 0")
            cond_feat = self.cond_embed(cond).unsqueeze(1).expand(-1, x_cont.size(1), -1)
        cont_feat = self.cont_proj(x_cont)
        parts = [cont_feat, disc_feat, time_emb]
        if pos_emb is not None:
            parts.append(pos_emb.unsqueeze(0).expand(x_cont.size(0), -1, -1))
        if cond_feat is not None:
            parts.append(cond_feat)
        feat = torch.cat(parts, dim=-1)
        feat = self.in_proj(feat)
        if self.backbone_type == "transformer":
            out = self.backbone(feat)
        else:
            out, _ = self.backbone(feat)
        out = self.post_norm(out)
        out = out + self.post_ff(out)
        eps_pred = self.cont_head(out)
        if self.use_tanh_eps:
            eps_pred = torch.tanh(eps_pred) * self.eps_scale
        logits = [head(out) for head in self.disc_heads]
        return eps_pred, logits

    @staticmethod
    def _positional_encoding(seq_len: int, dim: int, device: torch.device) -> torch.Tensor:
        pos = torch.arange(seq_len, device=device).float().unsqueeze(1)
        div = torch.exp(torch.arange(0, dim, 2, device=device).float() * (-math.log(10000.0) / dim))
        pe = torch.zeros(seq_len, dim, device=device)
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        return pe
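
# --- Minimal end-to-end sketch (assumption, not part of the original scaffold) ---
# One hybrid training step: an eps-prediction MSE loss on the continuous channels
# plus a masked cross-entropy loss on the discrete channels. All sizes below are
# illustrative placeholders.
if __name__ == "__main__":
    timesteps = 100
    betas = cosine_beta_schedule(timesteps)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    cont_dim, disc_vocab_sizes = 3, [10, 5]
    model = HybridDiffusionModel(cont_dim=cont_dim, disc_vocab_sizes=disc_vocab_sizes)

    x_cont = torch.randn(4, 16, cont_dim)  # (B, T, Cc)
    x_disc = torch.stack(
        [torch.randint(0, v, (4, 16)) for v in disc_vocab_sizes], dim=-1
    )  # (B, T, Cd)
    t = torch.randint(0, timesteps, (4,))
    mask_tokens = torch.tensor(disc_vocab_sizes)  # each vocabulary's extra slot

    xt, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
    x_masked, mask = q_sample_discrete(x_disc, t, mask_tokens, max_t=timesteps)

    eps_pred, logits = model(xt, x_masked, t)

    cont_loss = F.mse_loss(eps_pred, noise)
    disc_loss = torch.tensor(0.0)
    for i, logit in enumerate(logits):
        col_mask = mask[:, :, i]  # only masked positions carry a denoising signal
        if col_mask.any():
            disc_loss = disc_loss + F.cross_entropy(logit[col_mask], x_disc[:, :, i][col_mask])
    loss = cont_loss + disc_loss
    print(f"cont={cont_loss.item():.4f} disc={disc_loss.item():.4f} total={loss.item():.4f}")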