Files
mask-ddpm/example/export_samples.py
2026-02-04 02:40:57 +08:00

435 lines
19 KiB
Python

#!/usr/bin/env python3
"""Sample from a trained hybrid diffusion model and export to CSV."""
import argparse
import csv
import gzip
import json
import os
from pathlib import Path
from typing import Dict, List
import torch
import torch.nn.functional as F
from data_utils import load_split, inverse_quantile_transform, quantile_calibrate_to_real
from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, TemporalTransformerGenerator, cosine_beta_schedule
from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path
def load_vocab(path: str) -> Dict[str, Dict[str, int]]:
    """Return the per-column token->id mapping stored under ``"vocab"`` in *path*.

    Args:
        path: Path to the discrete-vocabulary JSON file.

    Raises:
        KeyError: if the JSON document has no ``"vocab"`` key.
    """
    with open(path, "r", encoding="utf-8") as handle:
        payload = json.load(handle)
    return payload["vocab"]
def load_stats(path: str):
    """Parse the continuous-feature statistics JSON at *path* and return it unchanged.

    Args:
        path: Path to the statistics JSON file (read as UTF-8).
    """
    with open(path, encoding="utf-8") as stats_file:
        parsed = json.load(stats_file)
    return parsed
def read_header(path: str) -> List[str]:
    """Return the column names from the first row of a CSV file.

    Files whose name ends in ``.gz`` are decompressed transparently.

    Args:
        path: CSV file path (``str`` or any ``os.PathLike``).

    Returns:
        The header fields as a list of strings.

    Raises:
        ValueError: if the file is empty and therefore has no header row.
    """
    path_str = str(path)  # accept pathlib.Path as well as str
    if path_str.endswith(".gz"):
        opener, mode = gzip.open, "rt"
    else:
        opener, mode = open, "r"
    with opener(path_str, mode, newline="", encoding="utf-8") as f:
        # next() with a default avoids leaking a bare StopIteration on empty files.
        header = next(csv.reader(f), None)
    if header is None:
        raise ValueError("CSV file has no header row: %s" % path_str)
    return header
def build_inverse_vocab(vocab: Dict[str, Dict[str, int]]) -> Dict[str, List[str]]:
    """Invert each column's token->id mapping into an id-indexed token list.

    For every column the result is a list of length ``len(mapping)`` where
    position ``id`` holds the token mapped to that id; positions no token
    maps to stay ``""``.

    Raises:
        IndexError: if an id is >= the number of tokens in its column.
    """
    def _invert(token_to_id: Dict[str, int]) -> List[str]:
        table = [""] * len(token_to_id)
        for token, token_id in token_to_id.items():
            table[token_id] = token
        return table

    return {column: _invert(mapping) for column, mapping in vocab.items()}
def parse_args(argv=None):
    """Parse command-line options for sampling and CSV export.

    Path defaults point at the HAI 21.03 training files under
    ``<grandparent of this script's directory>/dataset/hai/hai-21.03`` and at
    the ``results/`` directory next to this script.

    Args:
        argv: Optional argument list to parse instead of ``sys.argv[1:]``.
            ``None`` (the default) keeps the normal command-line behavior.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description="Sample and export HAI feature sequences.")
    base_dir = Path(__file__).resolve().parent
    repo_dir = base_dir.parent.parent
    hai_dir = repo_dir / "dataset" / "hai" / "hai-21.03"
    results_dir = base_dir / "results"
    parser.add_argument("--data-path", default=str(hai_dir / "train1.csv.gz"))
    parser.add_argument("--data-glob", default=str(hai_dir / "train*.csv.gz"))
    parser.add_argument("--split-path", default=str(base_dir / "feature_split.json"))
    parser.add_argument("--stats-path", default=str(results_dir / "cont_stats.json"))
    parser.add_argument("--vocab-path", default=str(results_dir / "disc_vocab.json"))
    parser.add_argument("--model-path", default=str(results_dir / "model.pt"))
    parser.add_argument("--out", default=str(results_dir / "generated.csv"))
    parser.add_argument("--timesteps", type=int, default=200)
    parser.add_argument("--seq-len", type=int, default=64)
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
    parser.add_argument("--include-time", action="store_true", help="Include time column as a simple index")
    parser.add_argument("--clip-k", type=float, default=5.0, help="Clip continuous values to mean±k*std")
    parser.add_argument("--use-ema", action="store_true", help="Use EMA weights if available")
    parser.add_argument("--config", default=None, help="Optional config_used.json to infer conditioning")
    parser.add_argument("--condition-id", type=int, default=-1, help="Condition file id (0..N-1), -1=random")
    parser.add_argument("--include-condition", action="store_true", help="Include condition id column in CSV")
    return parser.parse_args(argv)
# Device selection uses the resolve_device function from platform_utils.
def _glob_sorted(pattern: str) -> List[Path]:
    """Return the paths matching *pattern*, sorted.

    Only the last path component is treated as a glob pattern; the directory
    part is used literally.
    """
    pattern_path = Path(pattern)
    return sorted(pattern_path.parent.glob(pattern_path.name))


def _open_csv_text(path):
    """Open *path* as UTF-8 text for the csv module; ``.gz`` files are decompressed."""
    if str(path).endswith(".gz"):
        return gzip.open(path, "rt", newline="", encoding="utf-8")
    return open(path, "r", newline="", encoding="utf-8")


def _load_type1_condition(ref_path, type1_cols, seq_len, batch_size, mean, std, device):
    """Replay the first *seq_len* rows of *ref_path* as normalized type1 conditioning.

    Only the rows that are needed are read. Each type1 column is z-normalized
    with the training ``mean``/``std`` and the same sequence is repeated for
    every batch element.

    Returns:
        A ``(batch_size, seq_len, len(type1_cols))`` tensor on *device*, or
        ``None`` when the file has fewer than *seq_len* data rows.
    """
    import itertools

    with _open_csv_text(ref_path) as fh:
        rows = list(itertools.islice(csv.DictReader(fh), seq_len))
    if len(rows) < seq_len:
        return None
    raw = torch.tensor(
        [[float(row[c]) for c in type1_cols] for row in rows], device=device
    ).reshape(len(rows), len(type1_cols))
    mean_vec = torch.tensor([mean[c] for c in type1_cols], dtype=raw.dtype, device=device)
    std_vec = torch.tensor([std[c] for c in type1_cols], dtype=raw.dtype, device=device)
    normalized = (raw - mean_vec) / std_vec
    return normalized.unsqueeze(0).repeat(batch_size, 1, 1)


def main():
    """Sample sequences from the trained hybrid diffusion model and export them to CSV.

    Steps:
      1. Resolve CLI paths relative to this script and load the feature split,
         continuous statistics and discrete vocabulary.
      2. Rebuild ``HybridDiffusionModel`` (and, if configured, the stage-1
         temporal generator) from ``--config`` and load checkpoint weights.
      3. Run the reverse diffusion loop: a DDPM ancestral step for continuous
         channels and mask-token filling for discrete channels.
      4. Map continuous values back to raw units, re-insert type1 (replayed)
         and type5 (mirrored) columns, and write one CSV row per
         (batch element, time step).

    Raises:
        SystemExit: on a missing model/temporal checkpoint, when file-id
            conditioning is enabled but no files match ``data_glob``, when
            ``--condition-id`` is out of range, or when
            ``use_quantile_transform`` is set but the stats file has no
            ``quantile_values``.
    """
    args = parse_args()
    base_dir = Path(__file__).resolve().parent
    # Relative CLI paths are resolved against this script's directory.
    args.data_path = str(resolve_path(base_dir, args.data_path))
    args.data_glob = str(resolve_path(base_dir, args.data_glob)) if args.data_glob else ""
    args.split_path = str(resolve_path(base_dir, args.split_path))
    args.stats_path = str(resolve_path(base_dir, args.stats_path))
    args.vocab_path = str(resolve_path(base_dir, args.vocab_path))
    args.model_path = str(resolve_path(base_dir, args.model_path))
    args.out = str(resolve_path(base_dir, args.out))
    if not os.path.exists(args.model_path):
        raise SystemExit("missing model file: %s" % args.model_path)

    # The output header is copied from the first file matching --data-glob,
    # falling back to --data-path when the glob is empty or matches nothing.
    data_path = args.data_path
    if args.data_glob:
        matches = _glob_sorted(args.data_glob)
        if matches:
            data_path = str(matches[0])

    # --- feature layout, statistics and vocabulary ------------------------
    split = load_split(args.split_path)
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    # Discrete columns whose name starts with "attack" are not generated.
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]

    stats = load_stats(args.stats_path)
    mean = stats["mean"]
    std = stats["std"]
    vmin = stats.get("min", {})
    vmax = stats.get("max", {})
    int_like = stats.get("int_like", {})
    max_decimals = stats.get("max_decimals", {})
    transforms = stats.get("transform", {})
    quantile_probs = stats.get("quantile_probs")
    quantile_values = stats.get("quantile_values")
    quantile_raw_values = stats.get("quantile_raw_values")

    with open(args.vocab_path, "r", encoding="utf-8") as f:
        vocab_json = json.load(f)
    vocab = vocab_json["vocab"]
    top_token = vocab_json.get("top_token", {})
    inv_vocab = build_inverse_vocab(vocab)
    vocab_sizes = [len(vocab[c]) for c in disc_cols]

    device = resolve_device(args.device)

    # --- training config (optional) ----------------------------------------
    cfg = {}
    use_condition = False
    cond_vocab_size = 0
    if args.config:
        args.config = str(resolve_path(base_dir, args.config))
    if args.config and os.path.exists(args.config):
        with open(args.config, "r", encoding="utf-8") as f:
            cfg = json.load(f)
        use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id"
        if use_condition:
            # One condition id per file matched by the training data glob,
            # resolved relative to the config file's directory.
            cfg_base = Path(args.config).resolve().parent
            cfg_glob = str(resolve_path(cfg_base, cfg.get("data_glob") or args.data_glob))
            cond_vocab_size = len(_glob_sorted(cfg_glob))
            if cond_vocab_size <= 0:
                raise SystemExit("use_condition enabled but no files matched data_glob: %s" % cfg_glob)

    cont_target = str(cfg.get("cont_target", "eps"))
    cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
    use_quantile = bool(cfg.get("use_quantile_transform", False))
    cont_bound_mode = str(cfg.get("cont_bound_mode", "clamp"))
    cont_bound_strength = float(cfg.get("cont_bound_strength", 1.0))
    post_scale_cfg = cfg.get("cont_post_scale", {})
    cont_post_scale = post_scale_cfg if isinstance(post_scale_cfg, dict) else {}
    cont_post_calibrate = bool(cfg.get("cont_post_calibrate", False))
    # Neither type1 nor type5 columns are produced by the diffusion model:
    # type1 columns are replayed from a reference file and fed in as
    # conditioning; type5 columns ending in "Z" are mirrored from their base
    # column at export time (others stay 0).
    type1_cols = [c for c in (cfg.get("type1_features", []) or []) if c in cont_cols]
    type5_cols = [c for c in (cfg.get("type5_features", []) or []) if c in cont_cols]
    model_cont_cols = [c for c in cont_cols if c not in type1_cols and c not in type5_cols]

    use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
    temporal_backbone = str(cfg.get("temporal_backbone", "gru"))
    temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
    temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
    temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
    temporal_pos_dim = int(cfg.get("temporal_pos_dim", 64))
    temporal_use_pos_embed = bool(cfg.get("temporal_use_pos_embed", True))
    temporal_transformer_num_layers = int(cfg.get("temporal_transformer_num_layers", 2))
    temporal_transformer_nhead = int(cfg.get("temporal_transformer_nhead", 4))
    temporal_transformer_ff_dim = int(cfg.get("temporal_transformer_ff_dim", 512))
    temporal_transformer_dropout = float(cfg.get("temporal_transformer_dropout", 0.1))
    backbone_type = str(cfg.get("backbone_type", "gru"))
    transformer_num_layers = int(cfg.get("transformer_num_layers", 2))
    transformer_nhead = int(cfg.get("transformer_nhead", 4))
    transformer_ff_dim = int(cfg.get("transformer_ff_dim", 512))
    transformer_dropout = float(cfg.get("transformer_dropout", 0.1))

    # --- models ------------------------------------------------------------
    model = HybridDiffusionModel(
        cont_dim=len(model_cont_cols),
        disc_vocab_sizes=vocab_sizes,
        time_dim=int(cfg.get("model_time_dim", 64)),
        hidden_dim=int(cfg.get("model_hidden_dim", 256)),
        num_layers=int(cfg.get("model_num_layers", 1)),
        dropout=float(cfg.get("model_dropout", 0.0)),
        ff_mult=int(cfg.get("model_ff_mult", 2)),
        pos_dim=int(cfg.get("model_pos_dim", 64)),
        use_pos_embed=bool(cfg.get("model_use_pos_embed", True)),
        backbone_type=backbone_type,
        transformer_num_layers=transformer_num_layers,
        transformer_nhead=transformer_nhead,
        transformer_ff_dim=transformer_ff_dim,
        transformer_dropout=transformer_dropout,
        cond_cont_dim=len(type1_cols),
        cond_vocab_size=cond_vocab_size if use_condition else 0,
        cond_dim=int(cfg.get("cond_dim", 32)),
        use_tanh_eps=bool(cfg.get("use_tanh_eps", False)),
        eps_scale=float(cfg.get("eps_scale", 1.0)),
    ).to(device)
    # The EMA checkpoint sits next to the main one as "<stem>_ema<suffix>"
    # (model.pt -> model_ema.pt); only the file name is rewritten, never the
    # directory part.
    model_path = Path(args.model_path)
    ema_path = model_path.with_name(model_path.stem + "_ema" + model_path.suffix)
    if args.use_ema and ema_path.exists():
        weights_path = ema_path
    else:
        if args.use_ema:
            print("warning: --use-ema set but %s not found; loading %s" % (ema_path, model_path))
        weights_path = model_path
    model.load_state_dict(torch.load(weights_path, map_location=device, weights_only=True))
    model.eval()

    temporal_model = None
    if use_temporal_stage1:
        if temporal_backbone == "transformer":
            temporal_model = TemporalTransformerGenerator(
                input_dim=len(model_cont_cols),
                hidden_dim=temporal_hidden_dim,
                num_layers=temporal_transformer_num_layers,
                nhead=temporal_transformer_nhead,
                ff_dim=temporal_transformer_ff_dim,
                dropout=temporal_transformer_dropout,
                pos_dim=temporal_pos_dim,
                use_pos_embed=temporal_use_pos_embed,
            ).to(device)
        else:
            temporal_model = TemporalGRUGenerator(
                input_dim=len(model_cont_cols),
                hidden_dim=temporal_hidden_dim,
                num_layers=temporal_num_layers,
                dropout=temporal_dropout,
            ).to(device)
        temporal_path = Path(args.model_path).with_name("temporal.pt")
        if not temporal_path.exists():
            raise SystemExit(f"missing temporal model file: {temporal_path}")
        temporal_model.load_state_dict(torch.load(temporal_path, map_location=device, weights_only=True))
        temporal_model.eval()

    # --- sampling setup (RNG call order: x_cont, cond, trend, loop) ------
    betas = cosine_beta_schedule(args.timesteps).to(device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    x_cont = torch.randn(args.batch_size, args.seq_len, len(model_cont_cols), device=device)
    # Discrete channels start fully masked. The mask id of column i is
    # vocab_sizes[i], i.e. one past the last real token id.
    x_disc = torch.full(
        (args.batch_size, args.seq_len, len(disc_cols)),
        0,
        device=device,
        dtype=torch.long,
    )
    mask_tokens = torch.tensor(vocab_sizes, device=device)
    for i in range(len(disc_cols)):
        x_disc[:, :, i] = mask_tokens[i]

    # File-id condition: random per batch element, or fixed via --condition-id.
    cond = None
    if use_condition:
        if args.condition_id < 0:
            cond = torch.randint(0, cond_vocab_size, (args.batch_size,), device=device)
        elif args.condition_id >= cond_vocab_size:
            raise SystemExit(
                "--condition-id %d out of range (0..%d)" % (args.condition_id, cond_vocab_size - 1)
            )
        else:
            cond = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long)

    # Type1 program conditioning (library replay of the first reference file).
    cond_cont = None
    if type1_cols:
        ref_glob = cfg.get("data_glob") or args.data_glob
        if ref_glob:
            if args.config:
                ref_glob = str(resolve_path(Path(args.config).parent, ref_glob))
            refs = _glob_sorted(ref_glob)
            if refs:
                cond_cont = _load_type1_condition(
                    refs[0], type1_cols, args.seq_len, args.batch_size, mean, std, device
                )
        if cond_cont is None:
            print("warning: no reference rows for type1 conditioning; sampling without cond_cont")

    trend = None
    if temporal_model is not None:
        trend = temporal_model.generate(args.batch_size, args.seq_len, device)

    # --- reverse diffusion ---------------------------------------------------
    for t in reversed(range(args.timesteps)):
        t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
        eps_pred, logits = model(x_cont, x_disc, t_batch, cond, cond_cont=cond_cont)

        # Continuous channels: DDPM ancestral step.
        a_t = alphas[t]
        a_bar_t = alphas_cumprod[t]
        if cont_target == "x0":
            # The model predicts x0; convert it to the equivalent epsilon.
            x0_pred = eps_pred
            if cont_clamp_x0 > 0:
                x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0)
            eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)
        coef1 = 1.0 / torch.sqrt(a_t)
        coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
        mean_x = coef1 * (x_cont - coef2 * eps_pred)
        if t > 0:
            noise = torch.randn_like(x_cont)
            x_cont = mean_x + torch.sqrt(betas[t]) * noise
        else:
            x_cont = mean_x
        if args.clip_k > 0:
            x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k)

        # Discrete channels. NOTE(review): every position starts masked, so
        # the first step with t > 0 samples all positions at once and later
        # t > 0 steps find nothing left to unmask; at t == 0 every position
        # is overwritten with the argmax prediction. Confirm this schedule is
        # intended.
        for i, logit in enumerate(logits):
            if t == 0:
                probs = F.softmax(logit, dim=-1)
                x_disc[:, :, i] = torch.argmax(probs, dim=-1)
            else:
                mask = x_disc[:, :, i] == mask_tokens[i]
                if mask.any():
                    probs = F.softmax(logit, dim=-1)
                    sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(
                        args.batch_size, args.seq_len
                    )
                    x_disc[:, :, i][mask] = sampled[mask]

    # The stage-1 temporal output is added back after denoising (presumably
    # the diffusion model was trained on the residual - confirm).
    if trend is not None:
        x_cont = x_cont + trend

    # --- back to raw units ---------------------------------------------------
    x_cont = x_cont.cpu()
    x_disc = x_disc.cpu()
    # Clip in normalized space to avoid extreme blow-up.
    if args.clip_k > 0:
        x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k)
    if use_quantile:
        if quantile_values is None:
            raise SystemExit(
                "use_quantile_transform is enabled but %s has no quantile_values" % args.stats_path
            )
        q_vals = {c: quantile_values[c] for c in model_cont_cols}
        x_cont = inverse_quantile_transform(x_cont, model_cont_cols, quantile_probs, q_vals)
    else:
        mean_vec = torch.tensor([mean[c] for c in model_cont_cols], dtype=x_cont.dtype)
        std_vec = torch.tensor([std[c] for c in model_cont_cols], dtype=x_cont.dtype)
        x_cont = x_cont * std_vec + mean_vec
    # Undo per-feature log1p transforms recorded in the stats file.
    for i, c in enumerate(model_cont_cols):
        if transforms.get(c) == "log1p":
            x_cont[:, :, i] = torch.expm1(x_cont[:, :, i])
    if cont_post_calibrate and quantile_raw_values and quantile_probs:
        q_raw = {c: quantile_raw_values[c] for c in model_cont_cols}
        x_cont = quantile_calibrate_to_real(x_cont, model_cont_cols, quantile_probs, q_raw)

    # Bound each feature to its observed [min, max] using the configured mode.
    if vmin and vmax and cont_bound_mode != "none":
        for i, c in enumerate(model_cont_cols):
            lo = vmin.get(c)
            hi = vmax.get(c)
            if lo is None or hi is None:
                continue
            lo = float(lo)
            hi = float(hi)
            col = x_cont[:, :, i]
            if cont_bound_mode == "sigmoid":
                x_cont[:, :, i] = lo + (hi - lo) * torch.sigmoid(col)
            elif cont_bound_mode == "soft_tanh":
                # Soft bound without hard piling at the edges.
                mid = 0.5 * (lo + hi)
                half = 0.5 * (hi - lo)
                denom = cont_bound_strength if cont_bound_strength > 0 else 1.0
                x_cont[:, :, i] = mid + half * torch.tanh(col / denom)
            else:
                x_cont[:, :, i] = torch.clamp(col, lo, hi)

    # Optional per-feature post-scaling; non-numeric entries are skipped.
    for i, c in enumerate(model_cont_cols):
        if c not in cont_post_scale:
            continue
        try:
            scale = float(cont_post_scale[c])
        except (TypeError, ValueError):
            print("warning: ignoring non-numeric cont_post_scale for %s: %r" % (c, cont_post_scale[c]))
            continue
        x_cont[:, :, i] = x_cont[:, :, i] * scale

    # --- assemble every continuous column ------------------------------------
    full_cont = torch.zeros(args.batch_size, args.seq_len, len(cont_cols), dtype=x_cont.dtype)
    for i, c in enumerate(model_cont_cols):
        full_cont[:, :, cont_cols.index(c)] = x_cont[:, :, i]
    if cond_cont is not None and type1_cols:
        mean_vec = torch.tensor([mean[c] for c in type1_cols], dtype=cond_cont.dtype)
        std_vec = torch.tensor([std[c] for c in type1_cols], dtype=cond_cont.dtype)
        cond_denorm = cond_cont.cpu() * std_vec + mean_vec
        for i, c in enumerate(type1_cols):
            full_cont[:, :, cont_cols.index(c)] = cond_denorm[:, :, i]
    for c in type5_cols:
        if c.endswith("Z"):
            base_col = c[:-1]
            if base_col in cont_cols:
                full_cont[:, :, cont_cols.index(c)] = full_cont[:, :, cont_cols.index(base_col)]

    # --- CSV export ----------------------------------------------------------
    header = read_header(data_path)
    out_cols = [c for c in header if c != time_col or args.include_time]
    write_cond = args.include_condition and use_condition
    if write_cond:
        out_cols = ["__cond_file_id"] + out_cols

    # Per-column number format, hoisted out of the row loop. None means
    # "integer-like: round to the nearest int".
    cont_formats = []
    for c in cont_cols:
        if int_like.get(c, False):
            cont_formats.append(None)
        else:
            dec = int(max_decimals.get(c, 6))
            cont_formats.append(("%%.%df" % dec) if dec > 0 else "%.0f")

    # Convert once to nested Python lists instead of reading tensor elements one at a time.
    cont_values = full_cont.tolist()
    disc_values = x_disc.tolist()
    cond_values = cond.tolist() if cond is not None else None

    Path(args.out).parent.mkdir(parents=True, exist_ok=True)
    with open(args.out, "w", newline="", encoding="utf-8") as f:
        # Header columns not filled below (e.g. attack* labels) are written
        # as empty strings by DictWriter's default restval.
        writer = csv.DictWriter(f, fieldnames=out_cols)
        writer.writeheader()
        row_index = 0
        for b in range(args.batch_size):
            for t in range(args.seq_len):
                row = {}
                if write_cond:
                    row["__cond_file_id"] = str(int(cond_values[b])) if cond_values is not None else "-1"
                if args.include_time and time_col in header:
                    row[time_col] = str(row_index)
                for i, c in enumerate(cont_cols):
                    val = cont_values[b][t][i]
                    fmt = cont_formats[i]
                    row[c] = str(int(round(val))) if fmt is None else fmt % val
                for i, c in enumerate(disc_cols):
                    tok_idx = disc_values[b][t][i]
                    tokens = inv_vocab[c]
                    tok = tokens[tok_idx] if tok_idx < len(tokens) else "<UNK>"
                    # Out-of-vocab ids (including the mask id) fall back to
                    # top_token[c] when one is recorded.
                    if tok == "<UNK>" and c in top_token:
                        tok = top_token[c]
                    row[c] = tok
                writer.writerow(row)
                row_index += 1
    print("exported_csv", args.out)
    print("rows", args.batch_size * args.seq_len)
if __name__ == "__main__":
    # Run sampling and CSV export only when executed as a script, not on import.
    main()