Update example and notes
This commit is contained in:
88
example/sample.py
Executable file
88
example/sample.py
Executable file
@@ -0,0 +1,88 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Sampling stub for hybrid diffusion (continuous + discrete)."""
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from data_utils import load_split
|
||||
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
|
||||
|
||||
SPLIT_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/feature_split.json"
|
||||
VOCAB_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/results/disc_vocab.json"
|
||||
MODEL_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/results/model.pt"
|
||||
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
TIMESTEPS = 200
|
||||
SEQ_LEN = 64
|
||||
BATCH_SIZE = 2
|
||||
|
||||
|
||||
def load_vocab():
|
||||
with open(VOCAB_PATH, "r", encoding="ascii") as f:
|
||||
return json.load(f)["vocab"]
|
||||
|
||||
|
||||
def main():
|
||||
split = load_split(SPLIT_PATH)
|
||||
cont_cols = split["continuous"]
|
||||
disc_cols = split["discrete"]
|
||||
|
||||
vocab = load_vocab()
|
||||
vocab_sizes = [len(vocab[c]) for c in disc_cols]
|
||||
|
||||
print("device", DEVICE)
|
||||
model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(DEVICE)
|
||||
if os.path.exists(MODEL_PATH):
|
||||
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True))
|
||||
model.eval()
|
||||
|
||||
betas = cosine_beta_schedule(TIMESTEPS).to(DEVICE)
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
|
||||
x_cont = torch.randn(BATCH_SIZE, SEQ_LEN, len(cont_cols), device=DEVICE)
|
||||
x_disc = torch.full((BATCH_SIZE, SEQ_LEN, len(disc_cols)), 0, device=DEVICE, dtype=torch.long)
|
||||
mask_tokens = torch.tensor(vocab_sizes, device=DEVICE)
|
||||
|
||||
# Initialize discrete with mask tokens
|
||||
for i in range(len(disc_cols)):
|
||||
x_disc[:, :, i] = mask_tokens[i]
|
||||
|
||||
for t in reversed(range(TIMESTEPS)):
|
||||
t_batch = torch.full((BATCH_SIZE,), t, device=DEVICE, dtype=torch.long)
|
||||
eps_pred, logits = model(x_cont, x_disc, t_batch)
|
||||
|
||||
# Continuous reverse step (DDPM): x_{t-1} mean
|
||||
a_t = alphas[t]
|
||||
a_bar_t = alphas_cumprod[t]
|
||||
coef1 = 1.0 / torch.sqrt(a_t)
|
||||
coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
|
||||
mean = coef1 * (x_cont - coef2 * eps_pred)
|
||||
if t > 0:
|
||||
noise = torch.randn_like(x_cont)
|
||||
x_cont = mean + torch.sqrt(betas[t]) * noise
|
||||
else:
|
||||
x_cont = mean
|
||||
|
||||
# Discrete: fill masked positions by sampling logits
|
||||
for i, logit in enumerate(logits):
|
||||
if t == 0:
|
||||
probs = F.softmax(logit, dim=-1)
|
||||
x_disc[:, :, i] = torch.argmax(probs, dim=-1)
|
||||
else:
|
||||
mask = x_disc[:, :, i] == mask_tokens[i]
|
||||
if mask.any():
|
||||
probs = F.softmax(logit, dim=-1)
|
||||
sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(BATCH_SIZE, SEQ_LEN)
|
||||
x_disc[:, :, i][mask] = sampled[mask]
|
||||
|
||||
print("sampled_cont_shape", tuple(x_cont.shape))
|
||||
print("sampled_disc_shape", tuple(x_disc.shape))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user