Update example and notes

This commit is contained in:
2026-01-09 02:14:20 +08:00
parent 200bdf6136
commit c0639386be
18 changed files with 31656 additions and 0 deletions

88
example/sample.py Executable file
View File

@@ -0,0 +1,88 @@
#!/usr/bin/env python3
"""Sampling stub for hybrid diffusion (continuous + discrete)."""
import json
import math
import os
import torch
import torch.nn.functional as F
from data_utils import load_split
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
SPLIT_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/feature_split.json"
VOCAB_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/results/disc_vocab.json"
MODEL_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/results/model.pt"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TIMESTEPS = 200
SEQ_LEN = 64
BATCH_SIZE = 2
def load_vocab():
with open(VOCAB_PATH, "r", encoding="ascii") as f:
return json.load(f)["vocab"]
def main():
split = load_split(SPLIT_PATH)
cont_cols = split["continuous"]
disc_cols = split["discrete"]
vocab = load_vocab()
vocab_sizes = [len(vocab[c]) for c in disc_cols]
print("device", DEVICE)
model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(DEVICE)
if os.path.exists(MODEL_PATH):
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True))
model.eval()
betas = cosine_beta_schedule(TIMESTEPS).to(DEVICE)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
x_cont = torch.randn(BATCH_SIZE, SEQ_LEN, len(cont_cols), device=DEVICE)
x_disc = torch.full((BATCH_SIZE, SEQ_LEN, len(disc_cols)), 0, device=DEVICE, dtype=torch.long)
mask_tokens = torch.tensor(vocab_sizes, device=DEVICE)
# Initialize discrete with mask tokens
for i in range(len(disc_cols)):
x_disc[:, :, i] = mask_tokens[i]
for t in reversed(range(TIMESTEPS)):
t_batch = torch.full((BATCH_SIZE,), t, device=DEVICE, dtype=torch.long)
eps_pred, logits = model(x_cont, x_disc, t_batch)
# Continuous reverse step (DDPM): x_{t-1} mean
a_t = alphas[t]
a_bar_t = alphas_cumprod[t]
coef1 = 1.0 / torch.sqrt(a_t)
coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
mean = coef1 * (x_cont - coef2 * eps_pred)
if t > 0:
noise = torch.randn_like(x_cont)
x_cont = mean + torch.sqrt(betas[t]) * noise
else:
x_cont = mean
# Discrete: fill masked positions by sampling logits
for i, logit in enumerate(logits):
if t == 0:
probs = F.softmax(logit, dim=-1)
x_disc[:, :, i] = torch.argmax(probs, dim=-1)
else:
mask = x_disc[:, :, i] == mask_tokens[i]
if mask.any():
probs = F.softmax(logit, dim=-1)
sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(BATCH_SIZE, SEQ_LEN)
x_disc[:, :, i][mask] = sampled[mask]
print("sampled_cont_shape", tuple(x_cont.shape))
print("sampled_disc_shape", tuple(x_disc.shape))
if __name__ == "__main__":
main()