#!/usr/bin/env python3
"""Sampling stub for hybrid diffusion (continuous + discrete)."""
import json
from pathlib import Path

import torch
import torch.nn.functional as F

from data_utils import load_split
from hybrid_diffusion import (
    HybridDiffusionModel,
    TemporalGRUGenerator,
    TemporalTransformerGenerator,
    cosine_beta_schedule,
)
from platform_utils import resolve_device

BASE_DIR = Path(__file__).resolve().parent
SPLIT_PATH = BASE_DIR / "feature_split.json"
VOCAB_PATH = BASE_DIR / "results" / "disc_vocab.json"
MODEL_PATH = BASE_DIR / "results" / "model.pt"
CONFIG_PATH = BASE_DIR / "results" / "config_used.json"

# Use resolve_device from platform_utils to pick the device.
DEVICE = resolve_device("auto")
TIMESTEPS = 200
SEQ_LEN = 64
BATCH_SIZE = 2
CLIP_K = 5.0


def load_torch_state(path: str, device: str):
    # Prefer weights_only=True (safer unpickling); older torch versions
    # lack the argument and raise TypeError, so fall back without it.
    try:
        return torch.load(path, map_location=device, weights_only=True)
    except TypeError:
        return torch.load(path, map_location=device)


def load_vocab():
    with open(str(VOCAB_PATH), "r", encoding="utf-8") as f:
        return json.load(f)["vocab"]
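

# Reference: the forward (noising) process that the reverse loop in main()
# inverts. A minimal sketch for orientation only, not called anywhere below;
# it assumes the standard DDPM parameterization implied by the reverse step:
#   x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I)
def q_sample(x0: torch.Tensor, t: int, alphas_cumprod: torch.Tensor) -> torch.Tensor:
    a_bar_t = alphas_cumprod[t]
    eps = torch.randn_like(x0)
    return torch.sqrt(a_bar_t) * x0 + torch.sqrt(1.0 - a_bar_t) * eps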


# Sampling only; no gradients needed anywhere in this script.
@torch.no_grad()
def main():
    cfg = {}
    if CONFIG_PATH.exists():
        with open(str(CONFIG_PATH), "r", encoding="utf-8") as f:
            cfg = json.load(f)
    # Sampling hyperparameters, with the module-level constants as fallbacks.
    timesteps = int(cfg.get("timesteps", TIMESTEPS))
    seq_len = int(cfg.get("sample_seq_len", cfg.get("seq_len", SEQ_LEN)))
    batch_size = int(cfg.get("sample_batch_size", cfg.get("batch_size", BATCH_SIZE)))
    clip_k = float(cfg.get("clip_k", CLIP_K))
    use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id"
    cond_dim = int(cfg.get("cond_dim", 32))
    use_tanh_eps = bool(cfg.get("use_tanh_eps", False))
    eps_scale = float(cfg.get("eps_scale", 1.0))
    use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
    temporal_backbone = str(cfg.get("temporal_backbone", "gru"))
    temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
    temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
    temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
    temporal_pos_dim = int(cfg.get("temporal_pos_dim", 64))
    temporal_use_pos_embed = bool(cfg.get("temporal_use_pos_embed", True))
    temporal_transformer_num_layers = int(cfg.get("temporal_transformer_num_layers", 2))
    temporal_transformer_nhead = int(cfg.get("temporal_transformer_nhead", 4))
    temporal_transformer_ff_dim = int(cfg.get("temporal_transformer_ff_dim", 512))
    temporal_transformer_dropout = float(cfg.get("temporal_transformer_dropout", 0.1))
    cont_target = str(cfg.get("cont_target", "eps"))
    cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
    model_time_dim = int(cfg.get("model_time_dim", 64))
    model_hidden_dim = int(cfg.get("model_hidden_dim", 256))
    model_num_layers = int(cfg.get("model_num_layers", 1))
    model_dropout = float(cfg.get("model_dropout", 0.0))
    model_ff_mult = int(cfg.get("model_ff_mult", 2))
    model_pos_dim = int(cfg.get("model_pos_dim", 64))
    model_use_pos = bool(cfg.get("model_use_pos_embed", True))
    backbone_type = str(cfg.get("backbone_type", "gru"))
    transformer_num_layers = int(cfg.get("transformer_num_layers", 2))
    transformer_nhead = int(cfg.get("transformer_nhead", 4))
    transformer_ff_dim = int(cfg.get("transformer_ff_dim", 512))
    transformer_dropout = float(cfg.get("transformer_dropout", 0.1))

    split = load_split(str(SPLIT_PATH))
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
    vocab = load_vocab()
    vocab_sizes = [len(vocab[c]) for c in disc_cols]
    print("device", DEVICE)

    # Condition vocabulary size = number of training files matched by data_glob.
    cond_vocab_size = 0
    if use_condition:
        data_glob = cfg.get("data_glob")
        if data_glob:
            base = Path(data_glob).parent
            pat = Path(data_glob).name
            cond_vocab_size = len(sorted(base.glob(pat)))

    model = HybridDiffusionModel(
        cont_dim=len(cont_cols),
        disc_vocab_sizes=vocab_sizes,
        time_dim=model_time_dim,
        hidden_dim=model_hidden_dim,
        num_layers=model_num_layers,
        dropout=model_dropout,
        ff_mult=model_ff_mult,
        pos_dim=model_pos_dim,
        use_pos_embed=model_use_pos,
        backbone_type=backbone_type,
        transformer_num_layers=transformer_num_layers,
        transformer_nhead=transformer_nhead,
        transformer_ff_dim=transformer_ff_dim,
        transformer_dropout=transformer_dropout,
        cond_vocab_size=cond_vocab_size,
        cond_dim=cond_dim,
        use_tanh_eps=use_tanh_eps,
        eps_scale=eps_scale,
    ).to(DEVICE)
    if MODEL_PATH.exists():
        model.load_state_dict(load_torch_state(str(MODEL_PATH), DEVICE))
    model.eval()

    temporal_model = None
    if use_temporal_stage1:
        if temporal_backbone == "transformer":
            temporal_model = TemporalTransformerGenerator(
                input_dim=len(cont_cols),
                hidden_dim=temporal_hidden_dim,
                num_layers=temporal_transformer_num_layers,
                nhead=temporal_transformer_nhead,
                ff_dim=temporal_transformer_ff_dim,
                dropout=temporal_transformer_dropout,
                pos_dim=temporal_pos_dim,
                use_pos_embed=temporal_use_pos_embed,
            ).to(DEVICE)
        else:
            temporal_model = TemporalGRUGenerator(
                input_dim=len(cont_cols),
                hidden_dim=temporal_hidden_dim,
                num_layers=temporal_num_layers,
                dropout=temporal_dropout,
            ).to(DEVICE)
        temporal_path = BASE_DIR / "results" / "temporal.pt"
        if not temporal_path.exists():
            raise SystemExit(f"missing temporal model file: {temporal_path}")
        temporal_model.load_state_dict(load_torch_state(str(temporal_path), DEVICE))
        temporal_model.eval()

    betas = cosine_beta_schedule(timesteps).to(DEVICE)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    # Continuous channels start from pure Gaussian noise; discrete channels
    # start fully masked (absorbing state). The mask token id for column i is
    # vocab_sizes[i], i.e. one past the last real vocabulary id.
    x_cont = torch.randn(batch_size, seq_len, len(cont_cols), device=DEVICE)
    x_disc = torch.zeros(batch_size, seq_len, len(disc_cols), device=DEVICE, dtype=torch.long)
    mask_tokens = torch.tensor(vocab_sizes, device=DEVICE)
    for i in range(len(disc_cols)):
        x_disc[:, :, i] = mask_tokens[i]

    cond = None
    if use_condition:
        if cond_vocab_size <= 0:
            raise SystemExit("use_condition enabled but no files matched data_glob")
        cond = torch.randint(0, cond_vocab_size, (batch_size,), device=DEVICE, dtype=torch.long)

    trend = None
    if temporal_model is not None:
        trend = temporal_model.generate(batch_size, seq_len, DEVICE)

    for t in reversed(range(timesteps)):
        t_batch = torch.full((batch_size,), t, device=DEVICE, dtype=torch.long)
        eps_pred, logits = model(x_cont, x_disc, t_batch, cond)

        a_t = alphas[t]
        a_bar_t = alphas_cumprod[t]
        if cont_target == "x0":
            # Model predicts x0 directly; convert it to an eps prediction so
            # the same DDPM update below applies.
            x0_pred = eps_pred
            if cont_clamp_x0 > 0:
                x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0)
            eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)

        # Continuous reverse step (DDPM posterior mean of x_{t-1}):
        #   mu_t = (x_t - (1 - a_t) / sqrt(1 - a_bar_t) * eps_pred) / sqrt(a_t)
        coef1 = 1.0 / torch.sqrt(a_t)
        coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
        mean = coef1 * (x_cont - coef2 * eps_pred)
        if t > 0:
            noise = torch.randn_like(x_cont)
            x_cont = mean + torch.sqrt(betas[t]) * noise
        else:
            x_cont = mean
        if clip_k > 0:
            x_cont = torch.clamp(x_cont, -clip_k, clip_k)

        # Discrete channels: sample tokens for still-masked positions. Note
        # that every position is masked at the first reverse step, so this
        # stub unmasks everything at once; t == 0 then re-commits every
        # position to its most likely token.
        for i, logit in enumerate(logits):
            if t == 0:
                x_disc[:, :, i] = torch.argmax(logit, dim=-1)
            else:
                mask = x_disc[:, :, i] == mask_tokens[i]
                if mask.any():
                    probs = F.softmax(logit, dim=-1)
                    sampled = torch.multinomial(
                        probs.view(-1, probs.size(-1)), 1
                    ).view(batch_size, seq_len)
                    x_disc[:, :, i][mask] = sampled[mask]

    if trend is not None:
        # Two-stage composition: the diffusion output is treated as a residual
        # around the stage-1 temporal trend.
        x_cont = x_cont + trend
    print("sampled_cont_shape", tuple(x_cont.shape))
    print("sampled_disc_shape", tuple(x_disc.shape))
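

# Illustrative only (not called above): map sampled ids back to raw categorical
# values. This assumes each vocab[col] is an id -> value list, consistent with
# how vocab_sizes is computed in main(), and that the sampled ids are valid
# non-mask tokens; adjust if the training pipeline stores the vocabulary
# differently.
def decode_discrete(x_disc: torch.Tensor, disc_cols, vocab):
    decoded = {}
    for i, col in enumerate(disc_cols):
        ids = x_disc[:, :, i].tolist()  # (batch, seq_len) nested lists of int ids
        decoded[col] = [[vocab[col][j] for j in row] for row in ids]
    return decoded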


if __name__ == "__main__":
    main()
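
# Note on inputs: this script expects training to have produced the artifacts
# referenced above: feature_split.json next to this file, plus
# results/config_used.json, results/model.pt, results/disc_vocab.json, and
# results/temporal.pt when use_temporal_stage1 is set in the config.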