#!/usr/bin/env python3 """Sampling stub for hybrid diffusion (continuous + discrete).""" import json import math import os from pathlib import Path import torch import torch.nn.functional as F from data_utils import load_split from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule from platform_utils import resolve_device, safe_path, ensure_dir BASE_DIR = Path(__file__).resolve().parent SPLIT_PATH = BASE_DIR / "feature_split.json" VOCAB_PATH = BASE_DIR / "results" / "disc_vocab.json" MODEL_PATH = BASE_DIR / "results" / "model.pt" # 使用 platform_utils 中的 resolve_device 函数 DEVICE = resolve_device("auto") TIMESTEPS = 200 SEQ_LEN = 64 BATCH_SIZE = 2 def load_vocab(): with open(str(VOCAB_PATH), "r", encoding="utf-8") as f: return json.load(f)["vocab"] def main(): split = load_split(str(SPLIT_PATH)) time_col = split.get("time_column", "time") cont_cols = [c for c in split["continuous"] if c != time_col] disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col] vocab = load_vocab() vocab_sizes = [len(vocab[c]) for c in disc_cols] print("device", DEVICE) model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(DEVICE) if MODEL_PATH.exists(): model.load_state_dict(torch.load(str(MODEL_PATH), map_location=DEVICE, weights_only=True)) model.eval() betas = cosine_beta_schedule(TIMESTEPS).to(DEVICE) alphas = 1.0 - betas alphas_cumprod = torch.cumprod(alphas, dim=0) x_cont = torch.randn(BATCH_SIZE, SEQ_LEN, len(cont_cols), device=DEVICE) x_disc = torch.full((BATCH_SIZE, SEQ_LEN, len(disc_cols)), 0, device=DEVICE, dtype=torch.long) mask_tokens = torch.tensor(vocab_sizes, device=DEVICE) # Initialize discrete with mask tokens for i in range(len(disc_cols)): x_disc[:, :, i] = mask_tokens[i] for t in reversed(range(TIMESTEPS)): t_batch = torch.full((BATCH_SIZE,), t, device=DEVICE, dtype=torch.long) eps_pred, logits = model(x_cont, x_disc, t_batch) # Continuous reverse step (DDPM): x_{t-1} mean a_t = alphas[t] a_bar_t = alphas_cumprod[t] coef1 = 1.0 / torch.sqrt(a_t) coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t) mean = coef1 * (x_cont - coef2 * eps_pred) if t > 0: noise = torch.randn_like(x_cont) x_cont = mean + torch.sqrt(betas[t]) * noise else: x_cont = mean # Discrete: fill masked positions by sampling logits for i, logit in enumerate(logits): if t == 0: probs = F.softmax(logit, dim=-1) x_disc[:, :, i] = torch.argmax(probs, dim=-1) else: mask = x_disc[:, :, i] == mask_tokens[i] if mask.any(): probs = F.softmax(logit, dim=-1) sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(BATCH_SIZE, SEQ_LEN) x_disc[:, :, i][mask] = sampled[mask] print("sampled_cont_shape", tuple(x_cont.shape)) print("sampled_disc_shape", tuple(x_disc.shape)) if __name__ == "__main__": main()