89 lines
2.9 KiB
Python
Executable File
89 lines
2.9 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""Sampling stub for hybrid diffusion (continuous + discrete)."""
|
|
|
|
import json
|
|
import math
|
|
import os
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from data_utils import load_split
|
|
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
|
|
|
|
# Input artifacts produced by earlier pipeline stages.
# NOTE(review): absolute, machine-specific paths — fine for a stub, but
# consider making these CLI arguments or environment-configurable.
SPLIT_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/feature_split.json"
VOCAB_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/results/disc_vocab.json"
MODEL_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/results/model.pt"

# Sampling configuration.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TIMESTEPS = 200   # number of reverse diffusion steps (must match training schedule — TODO confirm)
SEQ_LEN = 64      # sequence length of each generated sample
BATCH_SIZE = 2    # number of sequences generated per run
|
|
|
|
|
|
def load_vocab(path=None):
    """Load the discrete-feature vocabulary from a JSON file.

    Args:
        path: Optional path to the vocab JSON file. Defaults to the
            module-level ``VOCAB_PATH`` (kept backward compatible with
            the original zero-argument call in ``main``).

    Returns:
        The ``"vocab"`` entry of the JSON document — per its usage in
        ``main``, a mapping from discrete column name to that column's
        token list.

    Raises:
        FileNotFoundError: If the file does not exist.
        KeyError: If the document has no ``"vocab"`` key.
    """
    if path is None:
        # Resolved lazily so the default tracks the module constant.
        path = VOCAB_PATH
    # JSON is UTF-8 by specification (RFC 8259). The previous
    # encoding="ascii" would raise UnicodeDecodeError on an otherwise
    # valid vocab file containing any non-ASCII token.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)["vocab"]
|
|
|
|
|
|
def main():
    """Run ancestral sampling for the hybrid (continuous + discrete) model.

    Loads the feature split, vocabulary, and (if present) trained weights,
    then runs TIMESTEPS reverse steps: a standard DDPM update on the
    continuous channels and a mask-token fill-in on the discrete channels.
    Prints the shapes of the sampled tensors at the end.
    """
    # Feature split: which columns are continuous vs. discrete.
    # Assumes load_split returns a dict with "continuous"/"discrete"
    # lists of column names — TODO confirm against data_utils.
    split = load_split(SPLIT_PATH)
    cont_cols = split["continuous"]
    disc_cols = split["discrete"]

    vocab = load_vocab()
    # One vocabulary size per discrete column; the value len(vocab[c])
    # itself is reused below as that column's [MASK] token id.
    vocab_sizes = [len(vocab[c]) for c in disc_cols]

    print("device", DEVICE)
    model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(DEVICE)
    # Weights are optional: when MODEL_PATH is missing the stub samples
    # from a randomly initialized model. weights_only=True restricts
    # torch.load to tensor data (safe-loading, PyTorch >= 2.0 behavior).
    if os.path.exists(MODEL_PATH):
        model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True))
    model.eval()

    # Noise schedule: betas from the cosine schedule, with the usual
    # alpha and cumulative-product derivations.
    betas = cosine_beta_schedule(TIMESTEPS).to(DEVICE)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    # Initial state: continuous channels start at pure Gaussian noise;
    # discrete channels start as zeros and are overwritten with the
    # per-column mask token below. Shapes: (batch, seq, n_features).
    x_cont = torch.randn(BATCH_SIZE, SEQ_LEN, len(cont_cols), device=DEVICE)
    x_disc = torch.full((BATCH_SIZE, SEQ_LEN, len(disc_cols)), 0, device=DEVICE, dtype=torch.long)
    mask_tokens = torch.tensor(vocab_sizes, device=DEVICE)

    # Initialize discrete with mask tokens
    for i in range(len(disc_cols)):
        x_disc[:, :, i] = mask_tokens[i]

    # NOTE(review): no torch.no_grad()/inference_mode() wrapper here, so
    # the loop builds autograd graphs it never uses — consider adding one.
    for t in reversed(range(TIMESTEPS)):
        # Same timestep for every batch element.
        t_batch = torch.full((BATCH_SIZE,), t, device=DEVICE, dtype=torch.long)
        # Model returns noise prediction for the continuous channels and
        # a per-discrete-column list of logits — presumably each of shape
        # (batch, seq, vocab_size + 1); verify against HybridDiffusionModel.
        eps_pred, logits = model(x_cont, x_disc, t_batch)

        # Continuous reverse step (DDPM): x_{t-1} mean
        # mu_t = (1/sqrt(a_t)) * (x_t - ((1 - a_t)/sqrt(1 - a_bar_t)) * eps)
        a_t = alphas[t]
        a_bar_t = alphas_cumprod[t]
        coef1 = 1.0 / torch.sqrt(a_t)
        coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
        mean = coef1 * (x_cont - coef2 * eps_pred)
        if t > 0:
            # Posterior noise with the simple sigma_t^2 = beta_t choice.
            noise = torch.randn_like(x_cont)
            x_cont = mean + torch.sqrt(betas[t]) * noise
        else:
            # Final step is deterministic (no noise added at t == 0).
            x_cont = mean

        # Discrete: fill masked positions by sampling logits
        for i, logit in enumerate(logits):
            if t == 0:
                # Last step: commit every position to its argmax class.
                # (softmax is monotonic, so argmax over probs == argmax
                # over logits; kept as-is for doc-only review.)
                probs = F.softmax(logit, dim=-1)
                x_disc[:, :, i] = torch.argmax(probs, dim=-1)
            else:
                # Only positions still holding the mask token get sampled.
                # NOTE(review): since positions are never re-masked, every
                # mask is filled at the first reverse step and this branch
                # is a no-op for the remaining steps — intentional for a
                # stub, but a schedule-based unmasking may be intended.
                mask = x_disc[:, :, i] == mask_tokens[i]
                if mask.any():
                    probs = F.softmax(logit, dim=-1)
                    # Flatten to (batch*seq, vocab) for multinomial, then
                    # restore the (batch, seq) layout.
                    sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(BATCH_SIZE, SEQ_LEN)
                    x_disc[:, :, i][mask] = sampled[mask]

    print("sampled_cont_shape", tuple(x_cont.shape))
    print("sampled_disc_shape", tuple(x_disc.shape))
|
|
|
|
|
|
# Standard script entry point: only run sampling when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
|