#!/usr/bin/env python3
"""Sampling stub for hybrid diffusion (continuous + discrete)."""

import json
import math
import os
from pathlib import Path

import torch
import torch.nn.functional as F

from data_utils import load_split
from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, TemporalTransformerGenerator, cosine_beta_schedule
from platform_utils import resolve_device, safe_path, ensure_dir

BASE_DIR = Path(__file__).resolve().parent
SPLIT_PATH = BASE_DIR / "feature_split.json"
VOCAB_PATH = BASE_DIR / "results" / "disc_vocab.json"
MODEL_PATH = BASE_DIR / "results" / "model.pt"
CONFIG_PATH = BASE_DIR / "results" / "config_used.json"

# Device resolution is delegated to resolve_device from platform_utils.
DEVICE = resolve_device("auto")
TIMESTEPS = 200
SEQ_LEN = 64
BATCH_SIZE = 2
CLIP_K = 5.0
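# Note: the module-level values above are fallback defaults; main() prefers the
# corresponding entries in results/config_used.json whenever that file exists.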


def load_vocab():
    with open(str(VOCAB_PATH), "r", encoding="utf-8") as f:
        return json.load(f)["vocab"]


def main():
    cfg = {}
    if CONFIG_PATH.exists():
        with open(str(CONFIG_PATH), "r", encoding="utf-8") as f:
            cfg = json.load(f)
    timesteps = int(cfg.get("timesteps", TIMESTEPS))
    seq_len = int(cfg.get("sample_seq_len", cfg.get("seq_len", SEQ_LEN)))
    batch_size = int(cfg.get("sample_batch_size", cfg.get("batch_size", BATCH_SIZE)))
    clip_k = float(cfg.get("clip_k", CLIP_K))
    use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id"
    cond_dim = int(cfg.get("cond_dim", 32))
    use_tanh_eps = bool(cfg.get("use_tanh_eps", False))
    eps_scale = float(cfg.get("eps_scale", 1.0))
    use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
    temporal_backbone = str(cfg.get("temporal_backbone", "gru"))
    temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
    temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
    temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
    temporal_pos_dim = int(cfg.get("temporal_pos_dim", 64))
    temporal_use_pos_embed = bool(cfg.get("temporal_use_pos_embed", True))
    temporal_transformer_num_layers = int(cfg.get("temporal_transformer_num_layers", 2))
    temporal_transformer_nhead = int(cfg.get("temporal_transformer_nhead", 4))
    temporal_transformer_ff_dim = int(cfg.get("temporal_transformer_ff_dim", 512))
    temporal_transformer_dropout = float(cfg.get("temporal_transformer_dropout", 0.1))
    cont_target = str(cfg.get("cont_target", "eps"))
    cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
    model_time_dim = int(cfg.get("model_time_dim", 64))
    model_hidden_dim = int(cfg.get("model_hidden_dim", 256))
    model_num_layers = int(cfg.get("model_num_layers", 1))
    model_dropout = float(cfg.get("model_dropout", 0.0))
    model_ff_mult = int(cfg.get("model_ff_mult", 2))
    model_pos_dim = int(cfg.get("model_pos_dim", 64))
    model_use_pos = bool(cfg.get("model_use_pos_embed", True))
    backbone_type = str(cfg.get("backbone_type", "gru"))
    transformer_num_layers = int(cfg.get("transformer_num_layers", 2))
    transformer_nhead = int(cfg.get("transformer_nhead", 4))
    transformer_ff_dim = int(cfg.get("transformer_ff_dim", 512))
    transformer_dropout = float(cfg.get("transformer_dropout", 0.1))

    split = load_split(str(SPLIT_PATH))
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
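    # Only feature columns are generated: the time column and attack* label
    # columns are excluded from both the continuous and discrete sets.
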
    vocab = load_vocab()
    vocab_sizes = [len(vocab[c]) for c in disc_cols]

    print("device", DEVICE)
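    # When file-id conditioning is enabled, the condition vocabulary size is
    # inferred from the number of files matching data_glob (one id per file).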
    cond_vocab_size = 0
    if use_condition:
        data_glob = cfg.get("data_glob")
        if data_glob:
            base = Path(data_glob).parent
            pat = Path(data_glob).name
            cond_vocab_size = len(sorted(base.glob(pat)))
    model = HybridDiffusionModel(
        cont_dim=len(cont_cols),
        disc_vocab_sizes=vocab_sizes,
        time_dim=model_time_dim,
        hidden_dim=model_hidden_dim,
        num_layers=model_num_layers,
        dropout=model_dropout,
        ff_mult=model_ff_mult,
        pos_dim=model_pos_dim,
        use_pos_embed=model_use_pos,
        backbone_type=backbone_type,
        transformer_num_layers=transformer_num_layers,
        transformer_nhead=transformer_nhead,
        transformer_ff_dim=transformer_ff_dim,
        transformer_dropout=transformer_dropout,
        cond_vocab_size=cond_vocab_size,
        cond_dim=cond_dim,
        use_tanh_eps=use_tanh_eps,
        eps_scale=eps_scale,
    ).to(DEVICE)
    if MODEL_PATH.exists():
        model.load_state_dict(torch.load(str(MODEL_PATH), map_location=DEVICE, weights_only=True))
    model.eval()

    temporal_model = None
    if use_temporal_stage1:
        if temporal_backbone == "transformer":
            temporal_model = TemporalTransformerGenerator(
                input_dim=len(cont_cols),
                hidden_dim=temporal_hidden_dim,
                num_layers=temporal_transformer_num_layers,
                nhead=temporal_transformer_nhead,
                ff_dim=temporal_transformer_ff_dim,
                dropout=temporal_transformer_dropout,
                pos_dim=temporal_pos_dim,
                use_pos_embed=temporal_use_pos_embed,
            ).to(DEVICE)
        else:
            temporal_model = TemporalGRUGenerator(
                input_dim=len(cont_cols),
                hidden_dim=temporal_hidden_dim,
                num_layers=temporal_num_layers,
                dropout=temporal_dropout,
            ).to(DEVICE)
        temporal_path = BASE_DIR / "results" / "temporal.pt"
        if not temporal_path.exists():
            raise SystemExit(f"missing temporal model file: {temporal_path}")
        temporal_model.load_state_dict(torch.load(str(temporal_path), map_location=DEVICE, weights_only=True))
        temporal_model.eval()

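    # Precompute the DDPM noise schedule: cosine betas, alphas = 1 - betas, and
    # the cumulative products alpha_bar_t used in the reverse updates below.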
    betas = cosine_beta_schedule(timesteps).to(DEVICE)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

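    # The continuous channel starts from pure Gaussian noise; the discrete
    # channels start fully masked (see the mask-token initialization below).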
    x_cont = torch.randn(batch_size, seq_len, len(cont_cols), device=DEVICE)
    x_disc = torch.full((batch_size, seq_len, len(disc_cols)), 0, device=DEVICE, dtype=torch.long)
    mask_tokens = torch.tensor(vocab_sizes, device=DEVICE)

    # Initialize discrete with mask tokens
    for i in range(len(disc_cols)):
        x_disc[:, :, i] = mask_tokens[i]

    cond = None
    if use_condition:
        if cond_vocab_size <= 0:
            raise SystemExit("use_condition enabled but no files matched data_glob")
        cond = torch.randint(0, cond_vocab_size, (batch_size,), device=DEVICE, dtype=torch.long)

    trend = None
    if temporal_model is not None:
        trend = temporal_model.generate(batch_size, seq_len, DEVICE)

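    # Reverse diffusion loop. Each step predicts the continuous noise (or x0)
    # and per-column discrete logits, applies the standard DDPM posterior update
    # mean = (x_t - (1 - a_t) / sqrt(1 - a_bar_t) * eps) / sqrt(a_t) to the
    # continuous channel, and fills masked positions in the discrete channels.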
    for t in reversed(range(timesteps)):
        t_batch = torch.full((batch_size,), t, device=DEVICE, dtype=torch.long)
        eps_pred, logits = model(x_cont, x_disc, t_batch, cond)

        a_t = alphas[t]
        a_bar_t = alphas_cumprod[t]

        if cont_target == "x0":
            # Model predicts x0; convert to an eps estimate via x_t = sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * eps
            x0_pred = eps_pred
            if cont_clamp_x0 > 0:
                x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0)
            eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)

        # Continuous reverse step (DDPM): x_{t-1} mean
        coef1 = 1.0 / torch.sqrt(a_t)
        coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
        mean = coef1 * (x_cont - coef2 * eps_pred)
        if t > 0:
            noise = torch.randn_like(x_cont)
            x_cont = mean + torch.sqrt(betas[t]) * noise
        else:
            x_cont = mean
        if clip_k > 0:
            x_cont = torch.clamp(x_cont, -clip_k, clip_k)

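        # Note: with this stub's schedule, every masked position is filled on the
        # first reverse step (nothing is re-masked afterwards), and the final step
        # re-decodes all positions by argmax over the last logits.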
        # Discrete: fill masked positions by sampling logits
        for i, logit in enumerate(logits):
            if t == 0:
                probs = F.softmax(logit, dim=-1)
                x_disc[:, :, i] = torch.argmax(probs, dim=-1)
            else:
                mask = x_disc[:, :, i] == mask_tokens[i]
                if mask.any():
                    probs = F.softmax(logit, dim=-1)
                    sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(batch_size, seq_len)
                    x_disc[:, :, i][mask] = sampled[mask]

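    # Add the stage-1 temporal trend (if one was generated) back onto the
    # sampled continuous values.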
    if trend is not None:
        x_cont = x_cont + trend
    print("sampled_cont_shape", tuple(x_cont.shape))
    print("sampled_disc_shape", tuple(x_disc.shape))


if __name__ == "__main__":
    main()