#!/usr/bin/env python3
"""Sample from a trained hybrid diffusion model and export to CSV.

Continuous features are produced by a DDPM-style reverse process driven by the
model's noise prediction. Discrete features start as a per-column mask token
(index == vocab size) and are filled in from the model's per-column logits.
"""

import argparse
import csv
import gzip
import json
import os
from pathlib import Path
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F

from data_utils import load_split
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule


def load_vocab(path: str) -> Dict[str, Dict[str, int]]:
    """Return the ``"vocab"`` mapping (column -> token -> index) from a JSON file."""
    with open(path, "r", encoding="ascii") as f:
        return json.load(f)["vocab"]


def load_stats(path: str) -> dict:
    """Return the parsed JSON stats file (``main`` reads its ``mean``/``std`` keys)."""
    with open(path, "r", encoding="ascii") as f:
        return json.load(f)


def read_header(path: str) -> List[str]:
    """Return the header row of a CSV file, transparently handling ``.gz`` files.

    Raises:
        ValueError: if the file contains no rows at all.
    """
    opener = gzip.open if path.endswith(".gz") else open
    with opener(path, "rt", newline="") as f:
        # next(..., None) avoids leaking a bare StopIteration on an empty file.
        header = next(csv.reader(f), None)
    if header is None:
        raise ValueError("empty CSV file (no header row): %s" % path)
    return header


def build_inverse_vocab(vocab: Dict[str, Dict[str, int]]) -> Dict[str, List[str]]:
    """Invert each column's token -> index mapping into an index -> token list.

    Indices not used by any token map to ``""``. The list is sized to cover the
    largest index (so gaps in the numbering do not raise ``IndexError``) and is
    never shorter than the number of tokens in the mapping.
    """
    inv = {}
    for col, mapping in vocab.items():
        size = max(len(mapping), max(mapping.values(), default=-1) + 1)
        inverse = [""] * size
        for tok, idx in mapping.items():
            inverse[idx] = tok
        inv[col] = inverse
    return inv


def _positive_int(value: str) -> int:
    """argparse ``type=`` callback that accepts only integers >= 1."""
    number = int(value)
    if number < 1:
        raise argparse.ArgumentTypeError("expected a positive integer, got %r" % value)
    return number


def parse_args():
    """Parse command-line options; default paths are relative to this script."""
    parser = argparse.ArgumentParser(description="Sample and export HAI feature sequences.")
    base_dir = Path(__file__).resolve().parent
    repo_dir = base_dir.parent.parent
    parser.add_argument("--data-path", default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz"))
    parser.add_argument("--split-path", default=str(base_dir / "feature_split.json"))
    parser.add_argument("--stats-path", default=str(base_dir / "results" / "cont_stats.json"))
    parser.add_argument("--vocab-path", default=str(base_dir / "results" / "disc_vocab.json"))
    parser.add_argument("--model-path", default=str(base_dir / "results" / "model.pt"))
    parser.add_argument("--out", default=str(base_dir / "results" / "generated.csv"))
    parser.add_argument("--timesteps", type=_positive_int, default=200)
    parser.add_argument("--seq-len", type=_positive_int, default=64)
    parser.add_argument("--batch-size", type=_positive_int, default=2)
    parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
    parser.add_argument("--include-time", action="store_true", help="Include time column as a simple index")
    return parser.parse_args()


def resolve_device(mode: str) -> str:
    """Map ``cpu``/``cuda``/anything-else (auto) to a torch device string.

    Raises:
        SystemExit: if ``cuda`` is requested explicitly but is unavailable.
    """
    mode = mode.lower()
    if mode == "cpu":
        return "cpu"
    if mode == "cuda":
        if not torch.cuda.is_available():
            raise SystemExit("device set to cuda but CUDA is not available")
        return "cuda"
    # Any other value (including "auto") picks CUDA when present.
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


@torch.no_grad()
def _sample(
    model: torch.nn.Module,
    timesteps: int,
    batch_size: int,
    seq_len: int,
    cont_dim: int,
    vocab_sizes: List[int],
    device: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the reverse process and return ``(x_cont, x_disc)`` on the CPU.

    ``x_cont`` is ``(batch_size, seq_len, cont_dim)`` in normalized units.
    ``x_disc`` is ``(batch_size, seq_len, len(vocab_sizes))`` of token indices.
    Runs under ``no_grad`` so that the autograd graph is not accumulated across
    all ``timesteps`` model calls.
    """
    betas = cosine_beta_schedule(timesteps).to(device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    x_cont = torch.randn(batch_size, seq_len, cont_dim, device=device)
    # Mask token for column i is index vocab_sizes[i], one past the last real token.
    mask_tokens = torch.tensor(vocab_sizes, device=device, dtype=torch.long)
    x_disc = torch.empty((batch_size, seq_len, len(vocab_sizes)), device=device, dtype=torch.long)
    x_disc[:] = mask_tokens  # broadcast: every position starts masked

    for t in reversed(range(timesteps)):
        t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
        # assumes logits is a sequence with one (batch, seq_len, vocab) tensor per
        # discrete column -- TODO confirm against HybridDiffusionModel.forward
        eps_pred, logits = model(x_cont, x_disc, t_batch)

        # DDPM posterior mean from the predicted noise, with sigma_t^2 = beta_t.
        a_t = alphas[t]
        a_bar_t = alphas_cumprod[t]
        coef1 = 1.0 / torch.sqrt(a_t)
        coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
        mean_x = coef1 * (x_cont - coef2 * eps_pred)
        if t > 0:
            x_cont = mean_x + torch.sqrt(betas[t]) * torch.randn_like(x_cont)
        else:
            x_cont = mean_x

        for i, logit in enumerate(logits):
            if t == 0:
                # Final step: deterministic choice (argmax is invariant under softmax).
                x_disc[:, :, i] = torch.argmax(logit, dim=-1)
                continue
            # NOTE(review): all positions start masked, so the first reverse step
            # samples every position at once, and nothing stays masked afterwards
            # unless the model's logits cover the mask index. Confirm this matches
            # the training-time masking schedule.
            mask = x_disc[:, :, i] == mask_tokens[i]
            if not mask.any():
                continue
            probs = F.softmax(logit, dim=-1)
            sampled = torch.multinomial(probs.reshape(-1, probs.size(-1)), 1).view(batch_size, seq_len)
            x_disc[:, :, i] = torch.where(mask, sampled, x_disc[:, :, i])

    return x_cont.cpu(), x_disc.cpu()


def _write_csv(
    path: str,
    header: List[str],
    time_col: str,
    include_time: bool,
    cont_cols: List[str],
    disc_cols: List[str],
    x_cont: torch.Tensor,
    x_disc: torch.Tensor,
    inv_vocab: Dict[str, List[str]],
) -> int:
    """Write generated sequences to ``path`` in ``header`` column order.

    Returns the number of data rows written. Header columns that are not
    generated here (for example ``attack*`` columns, which ``main`` drops from
    the discrete set) are written as empty strings, the ``csv.DictWriter``
    default ``restval``.
    """
    out_cols = [c for c in header if c != time_col or include_time]
    out_dir = os.path.dirname(path)
    if out_dir:  # dirname("generated.csv") == "" and os.makedirs("") raises
        os.makedirs(out_dir, exist_ok=True)

    # One bulk conversion instead of per-element tensor indexing.
    cont_values = x_cont.tolist()
    disc_values = x_disc.tolist()

    row_index = 0
    with open(path, "w", newline="", encoding="ascii") as f:
        writer = csv.DictWriter(f, fieldnames=out_cols)
        writer.writeheader()
        for cont_seq, disc_seq in zip(cont_values, disc_values):
            for cont_step, disc_step in zip(cont_seq, disc_seq):
                row = {}
                if include_time and time_col in header:
                    row[time_col] = str(row_index)
                for col, value in zip(cont_cols, cont_step):
                    row[col] = "%.6f" % value
                for col, tok_idx in zip(disc_cols, disc_step):
                    tokens = inv_vocab[col]
                    # Out-of-vocabulary indices (e.g. a leftover mask token) fall back to "0".
                    row[col] = tokens[tok_idx] if 0 <= tok_idx < len(tokens) else "0"
                writer.writerow(row)
                row_index += 1
    return row_index


def main() -> None:
    """Load model and metadata, sample sequences, de-normalize, and export CSV."""
    args = parse_args()
    if not os.path.exists(args.model_path):
        raise SystemExit("missing model file: %s" % args.model_path)

    split = load_split(args.split_path)
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]

    stats = load_stats(args.stats_path)
    mean = stats["mean"]
    std = stats["std"]

    vocab = load_vocab(args.vocab_path)
    inv_vocab = build_inverse_vocab(vocab)
    vocab_sizes = [len(vocab[c]) for c in disc_cols]

    device = resolve_device(args.device)
    model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(device)
    model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True))
    model.eval()

    x_cont, x_disc = _sample(
        model,
        args.timesteps,
        args.batch_size,
        args.seq_len,
        len(cont_cols),
        vocab_sizes,
        device,
    )

    # Undo the per-column standardization applied to continuous features.
    mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
    std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
    x_cont = x_cont * std_vec + mean_vec

    header = read_header(args.data_path)
    rows = _write_csv(
        args.out,
        header,
        time_col,
        args.include_time,
        cont_cols,
        disc_cols,
        x_cont,
        x_disc,
        inv_vocab,
    )

    print("exported_csv", args.out)
    print("rows", rows)


if __name__ == "__main__":
    main()