Files
mask-ddpm/example/export_samples.py

184 lines
6.4 KiB
Python

#!/usr/bin/env python3
"""Sample from a trained hybrid diffusion model and export to CSV."""
import argparse
import csv
import gzip
import json
import os
from pathlib import Path
from typing import Dict, List
import torch
import torch.nn.functional as F
from data_utils import load_split
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
def load_vocab(path: str) -> Dict[str, Dict[str, int]]:
    """Return the per-column token->index mapping stored under "vocab" in *path*.

    The file is decoded as strict ASCII and parsed as JSON; a missing
    ``"vocab"`` key raises ``KeyError``.
    """
    document = json.loads(Path(path).read_text(encoding="ascii"))
    return document["vocab"]
def load_stats(path: str):
    """Return the parsed JSON document at *path* (decoded as strict ASCII).

    In this script the result is indexed as ``stats["mean"][col]`` and
    ``stats["std"][col]``.
    """
    raw_text = Path(path).read_text(encoding="ascii")
    return json.loads(raw_text)
def read_header(path: str) -> List[str]:
    """Return the column names from the first row of the CSV file at *path*.

    Paths ending in ``.gz`` are read through gzip in text mode; anything else
    is opened as plain text. An empty file raises ``StopIteration``.
    """
    if path.endswith(".gz"):
        stream = gzip.open(path, "rt", newline="")
    else:
        stream = open(path, "r", newline="")
    with stream as f:
        first_row = next(csv.reader(f))
    return first_row
def build_inverse_vocab(vocab: Dict[str, Dict[str, int]]) -> Dict[str, List[str]]:
    """Invert each column's token->index mapping into an index->token list.

    Each list has one slot per token. Slots not named by any index stay ``""``.
    An index outside ``0..len(mapping)-1`` raises ``IndexError``.
    """
    def _invert(mapping: Dict[str, int]) -> List[str]:
        tokens = [""] * len(mapping)
        for token, position in mapping.items():
            tokens[position] = token
        return tokens

    return {column: _invert(mapping) for column, mapping in vocab.items()}
def parse_args():
    """Parse command-line options for sampling and CSV export.

    Default file locations are derived from this script's directory: the
    model artifacts live in ``results/`` next to it. The training CSV
    lives under ``dataset/hai/hai-21.03`` two directories up.
    """
    script_dir = Path(__file__).resolve().parent
    results_dir = script_dir / "results"
    train_csv = script_dir.parent.parent / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz"

    parser = argparse.ArgumentParser(description="Sample and export HAI feature sequences.")
    # File locations, registered in the same order they appear in --help.
    for flag, default_path in (
        ("--data-path", train_csv),
        ("--split-path", script_dir / "feature_split.json"),
        ("--stats-path", results_dir / "cont_stats.json"),
        ("--vocab-path", results_dir / "disc_vocab.json"),
        ("--model-path", results_dir / "model.pt"),
        ("--out", results_dir / "generated.csv"),
    ):
        parser.add_argument(flag, default=str(default_path))
    # Integer sampling sizes.
    for flag, default_int in (("--timesteps", 200), ("--seq-len", 64), ("--batch-size", 2)):
        parser.add_argument(flag, type=int, default=default_int)
    parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
    parser.add_argument("--include-time", action="store_true", help="Include time column as a simple index")
    return parser.parse_args()
def resolve_device(mode: str) -> str:
    """Translate a ``--device`` value into ``"cpu"`` or ``"cuda"``.

    Matching is case-insensitive. ``"cpu"`` always yields CPU. ``"cuda"``
    exits via ``SystemExit`` when CUDA is unavailable. Any other value
    (including ``"auto"``) picks CUDA when available, else CPU.
    """
    choice = mode.lower()
    if choice == "cpu":
        return "cpu"
    has_cuda = torch.cuda.is_available()
    if choice == "cuda" and not has_cuda:
        raise SystemExit("device set to cuda but CUDA is not available")
    return "cuda" if has_cuda else "cpu"
def _sample_hybrid(
    model: torch.nn.Module,
    timesteps: int,
    batch_size: int,
    seq_len: int,
    cont_dim: int,
    vocab_sizes: List[int],
    device: str,
):
    """Run the reverse diffusion chain and return ``(x_cont, x_disc)`` on the CPU.

    ``x_cont`` holds model-space continuous values of shape
    ``(batch_size, seq_len, cont_dim)``. ``x_disc`` holds long token indices of
    shape ``(batch_size, seq_len, len(vocab_sizes))``. Index ``vocab_sizes[i]``
    serves as the mask token of discrete column ``i``.
    """
    betas = cosine_beta_schedule(timesteps).to(device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    x_cont = torch.randn(batch_size, seq_len, cont_dim, device=device)
    mask_tokens = torch.tensor(vocab_sizes, device=device, dtype=torch.long)
    # Every discrete position starts out as its column's mask token.
    x_disc = mask_tokens.expand(batch_size, seq_len, len(vocab_sizes)).clone()

    # Sampling needs no gradients. Without no_grad() each step's x_cont would
    # depend on the model output, so the autograd graph (and every step's
    # activations) would grow for the whole chain.
    with torch.no_grad():
        for t in reversed(range(timesteps)):
            t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
            # NOTE(review): assumes the model returns (eps prediction shaped like
            # x_cont, one logits tensor per discrete column of shape
            # (batch, seq_len, vocab)) — confirm against hybrid_diffusion.
            eps_pred, logits = model(x_cont, x_disc, t_batch)

            # DDPM ancestral step: mean = (x_t - beta_t / sqrt(1 - abar_t) * eps) / sqrt(alpha_t).
            a_t = alphas[t]
            a_bar_t = alphas_cumprod[t]
            coef1 = 1.0 / torch.sqrt(a_t)
            coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
            mean_x = coef1 * (x_cont - coef2 * eps_pred)
            if t > 0:
                noise = torch.randn_like(x_cont)
                x_cont = mean_x + torch.sqrt(betas[t]) * noise
            else:
                x_cont = mean_x

            # NOTE(review): every still-masked position is sampled at the first
            # reverse step (t = timesteps - 1). After that no mask tokens remain,
            # unless the model can emit the mask index itself. The final step
            # (t == 0) then overwrites every token with the argmax. Confirm this
            # matches the discrete masking schedule used in training.
            for i, logit in enumerate(logits):
                if t == 0:
                    x_disc[:, :, i] = torch.argmax(F.softmax(logit, dim=-1), dim=-1)
                    continue
                masked = x_disc[:, :, i] == mask_tokens[i]
                if not masked.any():
                    continue
                probs = F.softmax(logit, dim=-1)
                sampled = torch.multinomial(probs.reshape(-1, probs.size(-1)), 1)
                sampled = sampled.view(batch_size, seq_len)
                x_disc[:, :, i] = torch.where(masked, sampled, x_disc[:, :, i])

    return x_cont.cpu(), x_disc.cpu()


def _write_csv(
    path: str,
    fieldnames: List[str],
    time_col: str,
    write_time: bool,
    cont_cols: List[str],
    disc_cols: List[str],
    inv_vocab: Dict[str, List[str]],
    x_cont: torch.Tensor,
    x_disc: torch.Tensor,
) -> int:
    """Write one CSV row per (sequence, position) to *path* and return the row count.

    Continuous values are formatted with six decimals. Discrete indices are
    mapped back to tokens via *inv_vocab*, and ``"0"`` is written for an index
    past the end of the token list (such as a leftover mask token).
    ``csv.DictWriter`` leaves any field in *fieldnames* without a value empty.
    """
    out_dir = os.path.dirname(path)
    if out_dir:  # os.makedirs("") raises, so skip it for bare filenames
        os.makedirs(out_dir, exist_ok=True)
    # Convert once to nested Python lists instead of indexing tensors per cell.
    cont_rows = x_cont.tolist()
    disc_rows = x_disc.tolist()
    n_rows = 0
    with open(path, "w", newline="", encoding="ascii") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        for seq_cont, seq_disc in zip(cont_rows, disc_rows):
            for step_cont, step_disc in zip(seq_cont, seq_disc):
                row = {}
                if write_time:
                    # The time column is a running row index across all sequences.
                    row[time_col] = str(n_rows)
                for col, value in zip(cont_cols, step_cont):
                    row[col] = "%.6f" % value
                for col, tok_idx in zip(disc_cols, step_disc):
                    tokens = inv_vocab[col]
                    row[col] = tokens[tok_idx] if tok_idx < len(tokens) else "0"
                writer.writerow(row)
                n_rows += 1
    return n_rows


def main():
    """Sample sequences from a trained hybrid diffusion model and export them to CSV.

    Reads the feature split, continuous mean/std stats, discrete vocab and
    model checkpoint named on the command line. It then runs the reverse
    diffusion chain, maps continuous values back with ``x * std + mean``, and
    writes ``batch_size * seq_len`` rows to ``--out``. The column order comes
    from the header of ``--data-path``.

    Raises:
        SystemExit: if the model checkpoint does not exist, or (via
            ``resolve_device``) if ``--device cuda`` is requested but CUDA is
            unavailable.
    """
    args = parse_args()
    if not os.path.exists(args.model_path):
        raise SystemExit("missing model file: %s" % args.model_path)

    split = load_split(args.split_path)
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    # Columns whose name starts with "attack" are not generated.
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]

    stats = load_stats(args.stats_path)
    mean = stats["mean"]
    std = stats["std"]
    vocab = load_vocab(args.vocab_path)
    inv_vocab = build_inverse_vocab(vocab)
    vocab_sizes = [len(vocab[c]) for c in disc_cols]

    # Read the reference header before the (slow) sampling so a bad
    # --data-path fails fast.
    header = read_header(args.data_path)
    out_cols = [c for c in header if c != time_col or args.include_time]

    device = resolve_device(args.device)
    model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(device)
    # weights_only=True restricts unpickling to tensors and plain containers.
    model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True))
    model.eval()

    x_cont, x_disc = _sample_hybrid(
        model, args.timesteps, args.batch_size, args.seq_len, len(cont_cols), vocab_sizes, device
    )

    # Map model-space continuous values back to data scale, per column.
    mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
    std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
    x_cont = x_cont * std_vec + mean_vec

    n_rows = _write_csv(
        args.out,
        out_cols,
        time_col,
        args.include_time and time_col in header,
        cont_cols,
        disc_cols,
        inv_vocab,
        x_cont,
        x_disc,
    )
    print("exported_csv", args.out)
    print("rows", n_rows)
# Run the sampling/export pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()