Files
mask-ddpm/example/export_samples.py
2026-01-22 17:39:31 +08:00

175 lines
6.2 KiB
Python

#!/usr/bin/env python3
"""Sample from a trained hybrid diffusion model and export to CSV."""
import argparse
import csv
import gzip
import json
import os
from pathlib import Path
from typing import Dict, List
import torch
import torch.nn.functional as F
from data_utils import load_split
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
from platform_utils import resolve_device, safe_path, ensure_dir
def load_vocab(path: str) -> Dict[str, Dict[str, int]]:
    """Read the discrete-feature vocabulary (column -> token -> index) from JSON."""
    with open(path, "r", encoding="utf-8") as handle:
        payload = json.load(handle)
    return payload["vocab"]
def load_stats(path: str):
    """Load the continuous-feature normalization stats (JSON) as-is."""
    with open(path, "r", encoding="utf-8") as handle:
        return json.load(handle)
def read_header(path: str) -> List[str]:
    """Return the first (header) row of a CSV file, transparently handling gzip."""
    if path.endswith(".gz"):
        stream = gzip.open(path, "rt", newline="")
    else:
        stream = open(path, "r", newline="")
    with stream as handle:
        return next(csv.reader(handle))
def build_inverse_vocab(vocab: Dict[str, Dict[str, int]]) -> Dict[str, List[str]]:
    """Invert each column's token->index map into an index->token list."""

    def _invert(mapping: Dict[str, int]) -> List[str]:
        # Indices are assumed dense in [0, len(mapping)); unfilled slots stay "".
        tokens = [""] * len(mapping)
        for token, index in mapping.items():
            tokens[index] = token
        return tokens

    return {column: _invert(mapping) for column, mapping in vocab.items()}
def parse_args():
    """Build and parse the command line for sampling/export.

    Default paths are resolved relative to this script's directory and the
    repository root two levels up.
    """
    here = Path(__file__).resolve().parent
    repo_root = here.parent.parent
    results = here / "results"
    parser = argparse.ArgumentParser(description="Sample and export HAI feature sequences.")
    parser.add_argument("--data-path", default=str(repo_root / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz"))
    parser.add_argument("--split-path", default=str(here / "feature_split.json"))
    parser.add_argument("--stats-path", default=str(results / "cont_stats.json"))
    parser.add_argument("--vocab-path", default=str(results / "disc_vocab.json"))
    parser.add_argument("--model-path", default=str(results / "model.pt"))
    parser.add_argument("--out", default=str(results / "generated.csv"))
    parser.add_argument("--timesteps", type=int, default=200)
    parser.add_argument("--seq-len", type=int, default=64)
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
    parser.add_argument("--include-time", action="store_true", help="Include time column as a simple index")
    return parser.parse_args()
# Device selection is delegated to resolve_device from platform_utils.
def main():
    """Sample sequences from a trained hybrid diffusion model and export a CSV.

    Loads the feature split, normalization stats, discrete vocabularies, and
    model weights; runs the reverse diffusion process; de-normalizes the
    continuous channels; decodes discrete channels back to tokens; and writes
    rows in the original data file's column order.

    Raises:
        SystemExit: if the model checkpoint file does not exist.
    """
    args = parse_args()
    if not os.path.exists(args.model_path):
        raise SystemExit("missing model file: %s" % args.model_path)

    # Column bookkeeping: the time column is excluded from modeling and
    # attack-label columns are never generated.
    split = load_split(args.split_path)
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]

    stats = load_stats(args.stats_path)
    mean = stats["mean"]
    std = stats["std"]
    vocab = load_vocab(args.vocab_path)
    inv_vocab = build_inverse_vocab(vocab)
    vocab_sizes = [len(vocab[c]) for c in disc_cols]

    device = resolve_device(args.device)
    model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(device)
    model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True))
    model.eval()

    betas = cosine_beta_schedule(args.timesteps).to(device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    # Continuous channels start from Gaussian noise; discrete channels start
    # fully masked (the mask token index equals each column's vocab size).
    x_cont = torch.randn(args.batch_size, args.seq_len, len(cont_cols), device=device)
    x_disc = torch.full(
        (args.batch_size, args.seq_len, len(disc_cols)),
        0,
        device=device,
        dtype=torch.long,
    )
    mask_tokens = torch.tensor(vocab_sizes, device=device)
    for i in range(len(disc_cols)):
        x_disc[:, :, i] = mask_tokens[i]

    # no_grad: sampling needs no autograd graph; without it each reverse step
    # would keep every intermediate tensor alive and memory grows with T.
    with torch.no_grad():
        for t in reversed(range(args.timesteps)):
            t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
            eps_pred, logits = model(x_cont, x_disc, t_batch)
            # DDPM posterior mean for the continuous part:
            # mu = 1/sqrt(a_t) * (x - (1 - a_t)/sqrt(1 - a_bar_t) * eps_pred)
            a_t = alphas[t]
            a_bar_t = alphas_cumprod[t]
            coef1 = 1.0 / torch.sqrt(a_t)
            coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
            mean_x = coef1 * (x_cont - coef2 * eps_pred)
            if t > 0:
                noise = torch.randn_like(x_cont)
                x_cont = mean_x + torch.sqrt(betas[t]) * noise
            else:
                x_cont = mean_x
            # Discrete part: fill still-masked positions by sampling; at the
            # final step commit every position to its most likely token.
            for i, logit in enumerate(logits):
                if t == 0:
                    probs = F.softmax(logit, dim=-1)
                    x_disc[:, :, i] = torch.argmax(probs, dim=-1)
                else:
                    mask = x_disc[:, :, i] == mask_tokens[i]
                    if mask.any():
                        probs = F.softmax(logit, dim=-1)
                        sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(
                            args.batch_size, args.seq_len
                        )
                        x_disc[:, :, i][mask] = sampled[mask]

    # De-normalize continuous channels back to data units.
    x_cont = x_cont.cpu()
    x_disc = x_disc.cpu()
    mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
    std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
    x_cont = x_cont * std_vec + mean_vec

    header = read_header(args.data_path)
    out_cols = [c for c in header if c != time_col or args.include_time]
    # os.path.dirname returns "" for a bare filename and os.makedirs("")
    # raises FileNotFoundError, so only create the directory when one exists.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(args.out, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=out_cols)
        writer.writeheader()
        row_index = 0
        for b in range(args.batch_size):
            for t in range(args.seq_len):
                row = {}
                if args.include_time and time_col in header:
                    row[time_col] = str(row_index)
                for i, c in enumerate(cont_cols):
                    row[c] = ("%.6f" % float(x_cont[b, t, i]))
                for i, c in enumerate(disc_cols):
                    tok_idx = int(x_disc[b, t, i])
                    # The argmax at t == 0 can still pick the mask token
                    # (index == vocab size); fall back to "0" in that case.
                    tok = inv_vocab[c][tok_idx] if tok_idx < len(inv_vocab[c]) else "0"
                    row[c] = tok
                writer.writerow(row)
                row_index += 1
    print("exported_csv", args.out)
    print("rows", args.batch_size * args.seq_len)
# Script entry point: run the sampling/export pipeline when executed directly.
if __name__ == "__main__":
    main()