mask-ddpm/example/sample.py
#!/usr/bin/env python3
"""Sampling stub for hybrid diffusion (continuous + discrete)."""
import json
from pathlib import Path

import torch
import torch.nn.functional as F

from data_utils import load_split
from hybrid_diffusion import (
    HybridDiffusionModel,
    TemporalGRUGenerator,
    TemporalTransformerGenerator,
    cosine_beta_schedule,
)
from platform_utils import resolve_device
BASE_DIR = Path(__file__).resolve().parent
SPLIT_PATH = BASE_DIR / "feature_split.json"
VOCAB_PATH = BASE_DIR / "results" / "disc_vocab.json"
MODEL_PATH = BASE_DIR / "results" / "model.pt"
CONFIG_PATH = BASE_DIR / "results" / "config_used.json"
# Use resolve_device from platform_utils to pick the compute device.
DEVICE = resolve_device("auto")
TIMESTEPS = 200   # fallback diffusion steps when config_used.json is absent
SEQ_LEN = 64      # fallback sequence length
BATCH_SIZE = 2    # fallback number of sampled sequences
CLIP_K = 5.0      # clamp range for the continuous state during sampling
def load_torch_state(path: str, device: str):
    """Load a checkpoint, preferring weights_only (older torch lacks the kwarg)."""
    try:
        return torch.load(path, map_location=device, weights_only=True)
    except TypeError:
        return torch.load(path, map_location=device)


def load_vocab():
    with open(str(VOCAB_PATH), "r", encoding="utf-8") as f:
        return json.load(f)["vocab"]
@torch.no_grad()  # pure sampling; avoid building autograd graphs across the reverse loop
def main():
    cfg = {}
    if CONFIG_PATH.exists():
        with open(str(CONFIG_PATH), "r", encoding="utf-8") as f:
            cfg = json.load(f)
    timesteps = int(cfg.get("timesteps", TIMESTEPS))
    seq_len = int(cfg.get("sample_seq_len", cfg.get("seq_len", SEQ_LEN)))
    batch_size = int(cfg.get("sample_batch_size", cfg.get("batch_size", BATCH_SIZE)))
    clip_k = float(cfg.get("clip_k", CLIP_K))
    use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id"
    cond_dim = int(cfg.get("cond_dim", 32))
    use_tanh_eps = bool(cfg.get("use_tanh_eps", False))
    eps_scale = float(cfg.get("eps_scale", 1.0))
    use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
    temporal_backbone = str(cfg.get("temporal_backbone", "gru"))
    temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
    temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
    temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
    temporal_pos_dim = int(cfg.get("temporal_pos_dim", 64))
    temporal_use_pos_embed = bool(cfg.get("temporal_use_pos_embed", True))
    temporal_transformer_num_layers = int(cfg.get("temporal_transformer_num_layers", 2))
    temporal_transformer_nhead = int(cfg.get("temporal_transformer_nhead", 4))
    temporal_transformer_ff_dim = int(cfg.get("temporal_transformer_ff_dim", 512))
    temporal_transformer_dropout = float(cfg.get("temporal_transformer_dropout", 0.1))
    cont_target = str(cfg.get("cont_target", "eps"))
    cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
    model_time_dim = int(cfg.get("model_time_dim", 64))
    model_hidden_dim = int(cfg.get("model_hidden_dim", 256))
    model_num_layers = int(cfg.get("model_num_layers", 1))
    model_dropout = float(cfg.get("model_dropout", 0.0))
    model_ff_mult = int(cfg.get("model_ff_mult", 2))
    model_pos_dim = int(cfg.get("model_pos_dim", 64))
    model_use_pos = bool(cfg.get("model_use_pos_embed", True))
    backbone_type = str(cfg.get("backbone_type", "gru"))
    transformer_num_layers = int(cfg.get("transformer_num_layers", 2))
    transformer_nhead = int(cfg.get("transformer_nhead", 4))
    transformer_ff_dim = int(cfg.get("transformer_ff_dim", 512))
    transformer_dropout = float(cfg.get("transformer_dropout", 0.1))
    split = load_split(str(SPLIT_PATH))
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
    vocab = load_vocab()
    vocab_sizes = [len(vocab[c]) for c in disc_cols]
    print("device", DEVICE)
    cond_vocab_size = 0
    if use_condition:
        data_glob = cfg.get("data_glob")
        if data_glob:
            base = Path(data_glob).parent
            pat = Path(data_glob).name
            # One condition id per training file matched by data_glob.
            cond_vocab_size = len(sorted(base.glob(pat)))
    model = HybridDiffusionModel(
        cont_dim=len(cont_cols),
        disc_vocab_sizes=vocab_sizes,
        time_dim=model_time_dim,
        hidden_dim=model_hidden_dim,
        num_layers=model_num_layers,
        dropout=model_dropout,
        ff_mult=model_ff_mult,
        pos_dim=model_pos_dim,
        use_pos_embed=model_use_pos,
        backbone_type=backbone_type,
        transformer_num_layers=transformer_num_layers,
        transformer_nhead=transformer_nhead,
        transformer_ff_dim=transformer_ff_dim,
        transformer_dropout=transformer_dropout,
        cond_vocab_size=cond_vocab_size,
        cond_dim=cond_dim,
        use_tanh_eps=use_tanh_eps,
        eps_scale=eps_scale,
    ).to(DEVICE)
    if MODEL_PATH.exists():
        model.load_state_dict(load_torch_state(str(MODEL_PATH), DEVICE))
    else:
        print(f"warning: {MODEL_PATH} not found; sampling from randomly initialized weights")
    model.eval()
    temporal_model = None
    if use_temporal_stage1:
        if temporal_backbone == "transformer":
            temporal_model = TemporalTransformerGenerator(
                input_dim=len(cont_cols),
                hidden_dim=temporal_hidden_dim,
                num_layers=temporal_transformer_num_layers,
                nhead=temporal_transformer_nhead,
                ff_dim=temporal_transformer_ff_dim,
                dropout=temporal_transformer_dropout,
                pos_dim=temporal_pos_dim,
                use_pos_embed=temporal_use_pos_embed,
            ).to(DEVICE)
        else:
            temporal_model = TemporalGRUGenerator(
                input_dim=len(cont_cols),
                hidden_dim=temporal_hidden_dim,
                num_layers=temporal_num_layers,
                dropout=temporal_dropout,
            ).to(DEVICE)
        temporal_path = BASE_DIR / "results" / "temporal.pt"
        if not temporal_path.exists():
            raise SystemExit(f"missing temporal model file: {temporal_path}")
        temporal_model.load_state_dict(load_torch_state(str(temporal_path), DEVICE))
        temporal_model.eval()
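    # Precompute the forward-process schedule used by the reverse update below:
    # beta_t from the project's cosine schedule, alpha_t = 1 - beta_t, and
    # alpha_bar_t = prod_{s<=t} alpha_s (the cumulative signal-retention factor).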
    betas = cosine_beta_schedule(timesteps).to(DEVICE)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    x_cont = torch.randn(batch_size, seq_len, len(cont_cols), device=DEVICE)
    x_disc = torch.full((batch_size, seq_len, len(disc_cols)), 0, device=DEVICE, dtype=torch.long)
    mask_tokens = torch.tensor(vocab_sizes, device=DEVICE)
    # Initialize discrete features with mask tokens
    for i in range(len(disc_cols)):
        x_disc[:, :, i] = mask_tokens[i]
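    # The mask token id for column i is vocab_sizes[i], i.e. one past the last
    # real token, so the model's embedding tables presumably reserve an extra
    # row for it. The discrete chain starts fully masked and positions are
    # revealed as the reverse loop samples from the predicted logits.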
    cond = None
    if use_condition:
        if cond_vocab_size <= 0:
            raise SystemExit("use_condition enabled but no files matched data_glob")
        cond = torch.randint(0, cond_vocab_size, (batch_size,), device=DEVICE, dtype=torch.long)
    trend = None
    if temporal_model is not None:
        trend = temporal_model.generate(batch_size, seq_len, DEVICE)
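    # Reverse diffusion loop: at each step the model jointly predicts the
    # continuous noise (eps_pred) and per-column discrete logits from the
    # current noisy state.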
    for t in reversed(range(timesteps)):
        t_batch = torch.full((batch_size,), t, device=DEVICE, dtype=torch.long)
        eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
        a_t = alphas[t]
        a_bar_t = alphas_cumprod[t]
        if cont_target == "x0":
            # Model predicts x0 directly; convert to an eps prediction via
            # x_t = sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * eps.
            x0_pred = eps_pred
            if cont_clamp_x0 > 0:
                x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0)
            eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)
        # Continuous reverse step (DDPM posterior mean):
        # mu_t = (1 / sqrt(a_t)) * (x_t - (1 - a_t) / sqrt(1 - a_bar_t) * eps_pred)
        coef1 = 1.0 / torch.sqrt(a_t)
        coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
        mean = coef1 * (x_cont - coef2 * eps_pred)
        if t > 0:
            noise = torch.randn_like(x_cont)
            x_cont = mean + torch.sqrt(betas[t]) * noise
        else:
            x_cont = mean
        if clip_k > 0:
            x_cont = torch.clamp(x_cont, -clip_k, clip_k)
        # Discrete reverse step: fill still-masked positions by sampling logits
        for i, logit in enumerate(logits):
            if t == 0:
                # Final step: commit every position to its most likely token.
                x_disc[:, :, i] = torch.argmax(logit, dim=-1)
            else:
                mask = x_disc[:, :, i] == mask_tokens[i]
                if mask.any():
                    probs = F.softmax(logit, dim=-1)
                    sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(batch_size, seq_len)
                    x_disc[:, :, i][mask] = sampled[mask]
    if trend is not None:
        # Stage-1 temporal trend is generated separately; add it back to the
        # diffusion output once sampling has finished.
        x_cont = x_cont + trend
    print("sampled_cont_shape", tuple(x_cont.shape))
    print("sampled_disc_shape", tuple(x_disc.shape))


if __name__ == "__main__":
    main()
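# Example invocation (assumes a prior training run has written
# results/model.pt, results/disc_vocab.json, and results/config_used.json,
# plus feature_split.json next to this script):
#   python sample.py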