Files
mask-ddpm/example/export_samples.py

283 lines
11 KiB
Python

#!/usr/bin/env python3
"""Sample from a trained hybrid diffusion model and export to CSV."""
import argparse
import csv
import gzip
import json
import os
import sys
from pathlib import Path
from typing import Dict, List

import torch
import torch.nn.functional as F

from data_utils import load_split
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path
def load_vocab(path: str) -> Dict[str, Dict[str, int]]:
    """Read a vocabulary JSON file and return only its ``"vocab"`` section.

    Args:
        path: JSON file whose top-level object has a ``"vocab"`` key.

    Returns:
        The object stored under ``"vocab"`` (column -> token -> index).
    """
    with open(path, "r", encoding="utf-8") as handle:
        document = json.load(handle)
    return document["vocab"]
def load_stats(path: str):
    """Decode a statistics JSON file and return the resulting object as-is.

    Args:
        path: Path to a UTF-8 encoded JSON document.
    """
    with open(path, encoding="utf-8") as stats_file:
        stats = json.load(stats_file)
    return stats
def read_header(path: str) -> List[str]:
    """Return the column names from the first row of a CSV file.

    Paths ending in ``.gz`` are opened through :func:`gzip.open` in text mode;
    anything else uses the built-in :func:`open`.

    Args:
        path: Plain or gzip-compressed CSV file.

    Returns:
        The fields of the first CSV row.

    Raises:
        ValueError: If the file contains no rows at all.
    """
    if path.endswith(".gz"):
        opener, mode = gzip.open, "rt"
    else:
        opener, mode = open, "r"
    # Decode explicitly as UTF-8 (like the JSON loaders in this module) rather
    # than depending on the platform's locale encoding.
    with opener(path, mode, newline="", encoding="utf-8") as f:
        header = next(csv.reader(f), None)
    if header is None:
        # A bare next() would leak StopIteration to the caller, which is
        # confusing outside an iterator context.
        raise ValueError("CSV file has no header row: %s" % path)
    return header
def build_inverse_vocab(vocab: Dict[str, Dict[str, int]]) -> Dict[str, List[str]]:
    """Invert each column's token->index mapping into an index->token list.

    Args:
        vocab: Per-column mapping from token string to integer index.

    Returns:
        Dict keyed by column. Each value is a list of length ``len(mapping)``
        where position ``idx`` holds the token mapped to ``idx``. Positions not
        claimed by any token stay ``""``. An index >= the mapping's size raises
        ``IndexError``.
    """
    inverse_by_col: Dict[str, List[str]] = {}
    for column in vocab:
        token_to_idx = vocab[column]
        tokens = [""] * len(token_to_idx)
        for token, position in token_to_idx.items():
            tokens[position] = token
        inverse_by_col[column] = tokens
    return inverse_by_col
def parse_args():
    """Parse the command line for the sample-and-export script.

    Default file locations are derived from this script's directory:
    artefacts under ``<script_dir>/results`` and the HAI 21.03 CSVs under
    ``dataset/hai/hai-21.03`` two directories above the script.

    Returns:
        argparse.Namespace with one attribute per option registered below.
    """
    script_dir = Path(__file__).resolve().parent
    results = script_dir / "results"
    hai_root = script_dir.parent.parent / "dataset" / "hai" / "hai-21.03"

    option_table = [
        # Inputs and trained artefacts.
        ("--data-path", {"default": str(hai_root / "train1.csv.gz")}),
        ("--data-glob", {"default": str(hai_root / "train*.csv.gz")}),
        ("--split-path", {"default": str(script_dir / "feature_split.json")}),
        ("--stats-path", {"default": str(results / "cont_stats.json")}),
        ("--vocab-path", {"default": str(results / "disc_vocab.json")}),
        ("--model-path", {"default": str(results / "model.pt")}),
        ("--out", {"default": str(results / "generated.csv")}),
        # Sampling configuration.
        ("--timesteps", {"type": int, "default": 200}),
        ("--seq-len", {"type": int, "default": 64}),
        ("--batch-size", {"type": int, "default": 2}),
        ("--device", {"default": "auto", "help": "cpu, cuda, or auto"}),
        # Output shaping, weight selection and conditioning.
        ("--include-time", {"action": "store_true", "help": "Include time column as a simple index"}),
        ("--clip-k", {"type": float, "default": 5.0, "help": "Clip continuous values to mean±k*std"}),
        ("--use-ema", {"action": "store_true", "help": "Use EMA weights if available"}),
        ("--config", {"default": None, "help": "Optional config_used.json to infer conditioning"}),
        ("--condition-id", {"type": int, "default": -1, "help": "Condition file id (0..N-1), -1=random"}),
        ("--include-condition", {"action": "store_true", "help": "Include condition id column in CSV"}),
    ]

    cli = argparse.ArgumentParser(description="Sample and export HAI feature sequences.")
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
# Device selection (--device) is delegated to resolve_device from platform_utils.
def main():
    """Sample sequences from a trained HybridDiffusionModel and write them to CSV.

    Pipeline:
      1. Resolve CLI paths relative to this script's directory.
      2. Load the feature split, continuous stats and discrete vocabulary.
      3. Optionally read ``--config`` for model hyper-parameters and file-id
         conditioning.
      4. Build the model, load (optionally EMA) weights and run the reverse
         diffusion chain without autograd.
      5. Undo normalisation and write one CSV row per (sample, step), using the
         column order of the first file matched by ``--data-glob`` (or
         ``--data-path`` when nothing matches).

    Raises:
        SystemExit: if the model file is missing, conditioning is enabled but
            no files match the configured glob, or ``--condition-id`` is out
            of range.
    """
    args = parse_args()
    base_dir = Path(__file__).resolve().parent
    args.data_path = str(resolve_path(base_dir, args.data_path))
    args.data_glob = str(resolve_path(base_dir, args.data_glob)) if args.data_glob else ""
    args.split_path = str(resolve_path(base_dir, args.split_path))
    args.stats_path = str(resolve_path(base_dir, args.stats_path))
    args.vocab_path = str(resolve_path(base_dir, args.vocab_path))
    args.model_path = str(resolve_path(base_dir, args.model_path))
    args.out = str(resolve_path(base_dir, args.out))
    if not os.path.exists(args.model_path):
        raise SystemExit("missing model file: %s" % args.model_path)

    # Source of the output column order: first glob match, else --data-path.
    data_path = args.data_path
    if args.data_glob:
        matches = _sorted_glob(args.data_glob)
        if matches:
            data_path = str(matches[0])

    split = load_split(args.split_path)
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    # Discrete columns whose name starts with "attack" are not generated.
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]

    stats = load_stats(args.stats_path)
    mean = stats["mean"]
    std = stats["std"]
    vmin = stats.get("min", {})
    vmax = stats.get("max", {})
    int_like = stats.get("int_like", {})
    max_decimals = stats.get("max_decimals", {})
    transforms = stats.get("transform", {})

    with open(args.vocab_path, "r", encoding="utf-8") as f:
        vocab_json = json.load(f)
    vocab = vocab_json["vocab"]
    top_token = vocab_json.get("top_token", {})
    inv_vocab = build_inverse_vocab(vocab)
    vocab_sizes = [len(vocab[c]) for c in disc_cols]

    device = resolve_device(args.device)

    cfg = {}
    use_condition = False
    cond_vocab_size = 0
    if args.config:
        args.config = str(resolve_path(base_dir, args.config))
    if args.config and os.path.exists(args.config):
        with open(args.config, "r", encoding="utf-8") as f:
            cfg = json.load(f)
        use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id"
    if use_condition:
        # One condition id per file matched by the config's data_glob, which is
        # resolved relative to the config file's directory.
        cfg_base = Path(args.config).resolve().parent
        cfg_glob = str(resolve_path(cfg_base, cfg.get("data_glob", args.data_glob)))
        cond_vocab_size = len(_sorted_glob(cfg_glob))
        if cond_vocab_size <= 0:
            raise SystemExit("use_condition enabled but no files matched data_glob: %s" % cfg_glob)

    model = HybridDiffusionModel(
        cont_dim=len(cont_cols),
        disc_vocab_sizes=vocab_sizes,
        time_dim=int(cfg.get("model_time_dim", 64)),
        hidden_dim=int(cfg.get("model_hidden_dim", 256)),
        num_layers=int(cfg.get("model_num_layers", 1)),
        dropout=float(cfg.get("model_dropout", 0.0)),
        ff_mult=int(cfg.get("model_ff_mult", 2)),
        pos_dim=int(cfg.get("model_pos_dim", 64)),
        use_pos_embed=bool(cfg.get("model_use_pos_embed", True)),
        cond_vocab_size=cond_vocab_size if use_condition else 0,
        cond_dim=int(cfg.get("cond_dim", 32)),
        use_tanh_eps=bool(cfg.get("use_tanh_eps", False)),
        eps_scale=float(cfg.get("eps_scale", 1.0)),
    ).to(device)
    weights_path = _select_weights_path(args.model_path, args.use_ema)
    model.load_state_dict(torch.load(weights_path, map_location=device, weights_only=True))
    model.eval()

    # Initial noise is drawn before the condition ids so the RNG draw order is
    # unchanged: noise first, then condition ids.
    x_cont = torch.randn(args.batch_size, args.seq_len, len(cont_cols), device=device)

    cond = None
    if use_condition:
        if args.condition_id < 0:
            cond = torch.randint(0, cond_vocab_size, (args.batch_size,), device=device)
        elif args.condition_id >= cond_vocab_size:
            raise SystemExit(
                "condition-id %d out of range (0..%d)" % (args.condition_id, cond_vocab_size - 1)
            )
        else:
            cond = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long)

    x_cont, x_disc = _sample(model, x_cont, vocab_sizes, cond, args.timesteps, args.clip_k)
    x_cont = _denormalize(x_cont.cpu(), cont_cols, mean, std, vmin, vmax, transforms, args.clip_k)
    x_disc = x_disc.cpu()

    header = read_header(data_path)
    out_cols = [c for c in header if c != time_col or args.include_time]
    write_cond = args.include_condition and use_condition
    if write_cond:
        out_cols = ["__cond_file_id"] + out_cols

    # Per-column number formats are loop-invariant; None marks int-like columns.
    cont_formats = []
    for c in cont_cols:
        if int_like.get(c, False):
            cont_formats.append(None)
        else:
            dec = int(max_decimals.get(c, 6))
            cont_formats.append(("%%.%df" % dec) if dec > 0 else "%.0f")

    # Convert tensors to nested Python lists once instead of indexing per cell.
    cont_values = x_cont.tolist()
    disc_values = x_disc.tolist()
    cond_ids = cond.tolist() if cond is not None else None

    out_dir = os.path.dirname(args.out)
    if out_dir:  # dirname is "" for a bare file name and os.makedirs("") raises
        os.makedirs(out_dir, exist_ok=True)
    with open(args.out, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=out_cols)
        writer.writeheader()
        row_index = 0
        for b in range(args.batch_size):
            for t in range(args.seq_len):
                row = {}
                if write_cond:
                    row["__cond_file_id"] = str(int(cond_ids[b])) if cond_ids is not None else "-1"
                if args.include_time and time_col in header:
                    row[time_col] = str(row_index)
                for i, (c, fmt) in enumerate(zip(cont_cols, cont_formats)):
                    val = cont_values[b][t][i]
                    row[c] = str(int(round(val))) if fmt is None else fmt % val
                for i, c in enumerate(disc_cols):
                    tok_idx = disc_values[b][t][i]
                    tok = inv_vocab[c][tok_idx] if tok_idx < len(inv_vocab[c]) else "<UNK>"
                    if tok == "<UNK>" and c in top_token:
                        tok = top_token[c]
                    row[c] = tok
                writer.writerow(row)
                row_index += 1
    print("exported_csv", args.out)
    print("rows", args.batch_size * args.seq_len)


# --- helpers used by main() ---


def _sorted_glob(pattern: str) -> List[Path]:
    """Return the files matching ``pattern`` (directory part + file-name glob), sorted."""
    pattern_path = Path(pattern)
    return sorted(pattern_path.parent.glob(pattern_path.name))


def _select_weights_path(model_path: str, use_ema: bool) -> str:
    """Return the checkpoint to load, preferring the ``*_ema`` sibling when requested.

    The EMA file name is derived by replacing "model.pt" with "model_ema.pt" in
    the file name only; directory components are never rewritten. If EMA
    weights are requested but not found, a warning goes to stderr and the
    regular checkpoint is used.
    """
    if not use_ema:
        return model_path
    model_file = Path(model_path)
    ema_file = model_file.with_name(model_file.name.replace("model.pt", "model_ema.pt"))
    if ema_file != model_file and ema_file.exists():
        return str(ema_file)
    print(
        "warning: --use-ema set but no EMA weights found at %s; using %s" % (ema_file, model_path),
        file=sys.stderr,
    )
    return model_path


@torch.no_grad()
def _sample(model, x_cont, vocab_sizes, cond, timesteps, clip_k):
    """Run the reverse diffusion chain starting from the noise tensor ``x_cont``.

    Runs under ``torch.no_grad()``. Without it, each step's forward pass stays in
    the autograd graph, because ``x_cont`` feeds the next step, so memory grows
    with ``timesteps``.

    Args:
        model: Called as ``model(x_cont, x_disc, t_batch, cond)``. It returns
            ``(eps_pred, logits)``, where ``logits`` is iterated as one tensor
            per discrete column.
        x_cont: Initial noise of shape (batch, seq, n_cont).
        vocab_sizes: Vocabulary size per discrete column. Index
            ``vocab_sizes[i]`` is used as column i's mask token.
        cond: Condition ids of shape (batch,), or None.
        timesteps: Number of reverse steps.
        clip_k: If > 0, clamp continuous values to [-clip_k, clip_k] each step.

    Returns:
        ``(x_cont, x_disc)`` on the device of the input noise.
    """
    device = x_cont.device
    batch_size, seq_len = x_cont.shape[0], x_cont.shape[1]
    betas = cosine_beta_schedule(timesteps).to(device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    # All discrete positions start as the mask token.
    mask_tokens = torch.tensor(vocab_sizes, device=device)
    x_disc = torch.full((batch_size, seq_len, len(vocab_sizes)), 0, device=device, dtype=torch.long)
    for i in range(len(vocab_sizes)):
        x_disc[:, :, i] = mask_tokens[i]

    for t in reversed(range(timesteps)):
        t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
        eps_pred, logits = model(x_cont, x_disc, t_batch, cond)

        # DDPM ancestral update: mu = (x - (1 - a_t) / sqrt(1 - abar_t) * eps) / sqrt(a_t),
        # with noise variance beta_t for t > 0.
        a_t = alphas[t]
        a_bar_t = alphas_cumprod[t]
        coef1 = 1.0 / torch.sqrt(a_t)
        coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
        mean_x = coef1 * (x_cont - coef2 * eps_pred)
        if t > 0:
            x_cont = mean_x + torch.sqrt(betas[t]) * torch.randn_like(x_cont)
        else:
            x_cont = mean_x
        if clip_k > 0:
            # Clip in normalised space to avoid extreme blow-up.
            x_cont = torch.clamp(x_cont, -clip_k, clip_k)

        # NOTE(review): every position starts masked, so the first iteration
        # (t = timesteps - 1) samples all of them. Later t > 0 steps only touch
        # positions still equal to the mask index, which is none unless the
        # logits cover that index. At t == 0 the argmax overwrites every
        # position. Confirm this is the intended discrete schedule.
        for i, logit in enumerate(logits):
            if t == 0:
                probs = F.softmax(logit, dim=-1)
                x_disc[:, :, i] = torch.argmax(probs, dim=-1)
            else:
                mask = x_disc[:, :, i] == mask_tokens[i]
                if mask.any():
                    probs = F.softmax(logit, dim=-1)
                    sampled = torch.multinomial(probs.reshape(-1, probs.size(-1)), 1).view(
                        batch_size, seq_len
                    )
                    x_disc[:, :, i][mask] = sampled[mask]
    return x_cont, x_disc


def _denormalize(x_cont, cont_cols, mean, std, vmin, vmax, transforms, clip_k):
    """Map normalised continuous samples (batch, seq, n_cont) back to data scale.

    Processing steps:
      1. Clip to +/- ``clip_k`` in normalised space when ``clip_k > 0``.
      2. Undo the standardisation via ``x * std + mean``.
      3. Apply ``expm1`` to columns whose transform is ``"log1p"``.
      4. When both ``vmin`` and ``vmax`` are non-empty, clamp each column that
         has both bounds.
    """
    if clip_k > 0:
        x_cont = torch.clamp(x_cont, -clip_k, clip_k)
    mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
    std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
    x_cont = x_cont * std_vec + mean_vec
    for i, c in enumerate(cont_cols):
        if transforms.get(c) == "log1p":
            x_cont[:, :, i] = torch.expm1(x_cont[:, :, i])
    if vmin and vmax:
        for i, c in enumerate(cont_cols):
            lo = vmin.get(c, None)
            hi = vmax.get(c, None)
            if lo is not None and hi is not None:
                x_cont[:, :, i] = torch.clamp(x_cont[:, :, i], float(lo), float(hi))
    return x_cont
if __name__ == "__main__":
    # Only sample/export when executed as a script, never on import.
    main()