#!/usr/bin/env python3
"""Sample from a trained hybrid diffusion model and export to CSV."""

import argparse
import csv
import gzip
import json
import os
from pathlib import Path
from typing import Dict, List

import torch
import torch.nn.functional as F

from data_utils import load_split
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
from platform_utils import resolve_device, safe_path, ensure_dir


def load_vocab(path: str) -> Dict[str, Dict[str, int]]:
    """Load the discrete-column vocabulary from a JSON file.

    The file is expected to contain ``{"vocab": {column: {token: index}}}``;
    only the inner mapping is returned.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)["vocab"]


def load_stats(path: str):
    """Load normalization statistics (a JSON object with "mean"/"std" maps)."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def read_header(path: str) -> List[str]:
    """Return the first (header) row of a CSV file, handling ``.gz`` transparently."""
    if path.endswith(".gz"):
        opener = gzip.open
        mode = "rt"  # text mode so csv.reader gets str, not bytes
    else:
        opener = open
        mode = "r"
    with opener(path, mode, newline="") as f:
        reader = csv.reader(f)
        return next(reader)


def build_inverse_vocab(vocab: Dict[str, Dict[str, int]]) -> Dict[str, List[str]]:
    """Invert ``{column: {token: index}}`` into ``{column: [token, ...]}``.

    Assumes each column's indices are exactly 0..len(mapping)-1; any gap would
    leave an empty-string placeholder in the inverse list.
    """
    inv = {}
    for col, mapping in vocab.items():
        inverse = [""] * len(mapping)
        for tok, idx in mapping.items():
            inverse[idx] = tok
        inv[col] = inverse
    return inv


def parse_args():
    """Parse command-line arguments; defaults are rooted at this file's directory."""
    parser = argparse.ArgumentParser(description="Sample and export HAI feature sequences.")
    base_dir = Path(__file__).resolve().parent
    repo_dir = base_dir.parent.parent
    parser.add_argument(
        "--data-path",
        default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz"),
    )
    parser.add_argument("--split-path", default=str(base_dir / "feature_split.json"))
    parser.add_argument("--stats-path", default=str(base_dir / "results" / "cont_stats.json"))
    parser.add_argument("--vocab-path", default=str(base_dir / "results" / "disc_vocab.json"))
    parser.add_argument("--model-path", default=str(base_dir / "results" / "model.pt"))
    parser.add_argument("--out", default=str(base_dir / "results" / "generated.csv"))
    parser.add_argument("--timesteps", type=int, default=200)
    parser.add_argument("--seq-len", type=int, default=64)
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
    parser.add_argument(
        "--include-time",
        action="store_true",
        help="Include time column as a simple index",
    )
    return parser.parse_args()


def main():
    """Run reverse diffusion to generate sequences and export them as CSV rows."""
    args = parse_args()
    if not os.path.exists(args.model_path):
        raise SystemExit("missing model file: %s" % args.model_path)

    # Column layout: continuous features are denoised with DDPM; discrete
    # features use a mask-token absorbing process. Attack labels and the time
    # column are excluded from the discrete set.
    split = load_split(args.split_path)
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_cols = [
        c for c in split["discrete"] if not c.startswith("attack") and c != time_col
    ]

    stats = load_stats(args.stats_path)
    mean = stats["mean"]
    std = stats["std"]
    vocab = load_vocab(args.vocab_path)
    inv_vocab = build_inverse_vocab(vocab)
    vocab_sizes = [len(vocab[c]) for c in disc_cols]

    # Device resolution is delegated to platform_utils.resolve_device.
    device = resolve_device(args.device)
    model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(device)
    model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True))
    model.eval()

    # Precompute the noise schedule used by the reverse process.
    betas = cosine_beta_schedule(args.timesteps).to(device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    # Continuous channels start as pure Gaussian noise; discrete channels start
    # fully masked (the mask token index for column i equals vocab_sizes[i],
    # i.e. one past the last real token).
    x_cont = torch.randn(args.batch_size, args.seq_len, len(cont_cols), device=device)
    x_disc = torch.full(
        (args.batch_size, args.seq_len, len(disc_cols)),
        0,
        device=device,
        dtype=torch.long,
    )
    mask_tokens = torch.tensor(vocab_sizes, device=device)
    for i in range(len(disc_cols)):
        x_disc[:, :, i] = mask_tokens[i]

    for t in reversed(range(args.timesteps)):
        t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
        eps_pred, logits = model(x_cont, x_disc, t_batch)

        # DDPM posterior mean: mu = (x - (1-a_t)/sqrt(1-a_bar_t) * eps) / sqrt(a_t)
        a_t = alphas[t]
        a_bar_t = alphas_cumprod[t]
        coef1 = 1.0 / torch.sqrt(a_t)
        coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
        mean_x = coef1 * (x_cont - coef2 * eps_pred)
        if t > 0:
            noise = torch.randn_like(x_cont)
            x_cont = mean_x + torch.sqrt(betas[t]) * noise
        else:
            x_cont = mean_x  # final step is deterministic (no added noise)

        for i, logit in enumerate(logits):
            if t == 0:
                # NOTE(review): the final step argmaxes EVERY position, which
                # also overwrites tokens committed in earlier steps — confirm
                # this is the intended unmasking schedule.
                probs = F.softmax(logit, dim=-1)
                x_disc[:, :, i] = torch.argmax(probs, dim=-1)
            else:
                # Only still-masked positions are (re)sampled; committed tokens
                # are left untouched. NOTE(review): multinomial may re-draw the
                # mask index itself, leaving the position masked for later steps.
                mask = x_disc[:, :, i] == mask_tokens[i]
                if mask.any():
                    probs = F.softmax(logit, dim=-1)
                    sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(
                        args.batch_size, args.seq_len
                    )
                    x_disc[:, :, i][mask] = sampled[mask]

    x_cont = x_cont.cpu()
    x_disc = x_disc.cpu()

    # De-normalize continuous channels back to the original data scale.
    mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
    std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
    x_cont = x_cont * std_vec + mean_vec

    # Preserve the original column order from the source CSV header. Columns in
    # the header that we did not generate (e.g. attack labels) are emitted empty
    # by DictWriter's default restval.
    header = read_header(args.data_path)
    out_cols = [c for c in header if c != time_col or args.include_time]

    # Guard: dirname is "" when --out is a bare filename, and makedirs("")
    # would raise FileNotFoundError.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with open(args.out, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=out_cols)
        writer.writeheader()
        row_index = 0
        for b in range(args.batch_size):
            for t in range(args.seq_len):
                row = {}
                if args.include_time and time_col in header:
                    row[time_col] = str(row_index)
                for i, c in enumerate(cont_cols):
                    row[c] = ("%.6f" % float(x_cont[b, t, i]))
                for i, c in enumerate(disc_cols):
                    tok_idx = int(x_disc[b, t, i])
                    # A still-masked position carries index == vocab size, which
                    # is out of range for the inverse vocab — fall back to "0".
                    tok = inv_vocab[c][tok_idx] if tok_idx < len(inv_vocab[c]) else "0"
                    row[c] = tok
                writer.writerow(row)
                row_index += 1
    print("exported_csv", args.out)
    print("rows", args.batch_size * args.seq_len)


if __name__ == "__main__":
    main()