#!/usr/bin/env python3
"""Sample from a trained hybrid diffusion model and export to CSV."""
import argparse
import csv
import gzip
import json
import os
from pathlib import Path
from typing import Dict, List

import torch
import torch.nn.functional as F

from data_utils import load_split
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path


def load_vocab(path: str) -> Dict[str, Dict[str, int]]:
    """Load the per-column token->index vocabulary from a JSON file.

    The file is expected to hold {"vocab": {column: {token: index}}}.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)["vocab"]


def load_stats(path: str):
    """Load the continuous-feature statistics JSON (mean/std/min/max/...)."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def read_header(path: str) -> List[str]:
    """Return the CSV header row of *path*, transparently handling .gz files."""
    if path.endswith(".gz"):
        opener = gzip.open
        mode = "rt"
    else:
        opener = open
        mode = "r"
    with opener(path, mode, newline="") as f:
        reader = csv.reader(f)
        return next(reader)


def build_inverse_vocab(vocab: Dict[str, Dict[str, int]]) -> Dict[str, List[str]]:
    """Invert token->index maps into index->token lists, one per column.

    Indices absent from the mapping stay as "" so lookups never KeyError.
    """
    inv = {}
    for col, mapping in vocab.items():
        inverse = [""] * len(mapping)
        for tok, idx in mapping.items():
            inverse[idx] = tok
        inv[col] = inverse
    return inv


def parse_args():
    """Build and parse the CLI arguments for sampling/export."""
    parser = argparse.ArgumentParser(description="Sample and export HAI feature sequences.")
    base_dir = Path(__file__).resolve().parent
    repo_dir = base_dir.parent.parent
    parser.add_argument("--data-path", default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz"))
    parser.add_argument("--data-glob", default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz"))
    parser.add_argument("--split-path", default=str(base_dir / "feature_split.json"))
    parser.add_argument("--stats-path", default=str(base_dir / "results" / "cont_stats.json"))
    parser.add_argument("--vocab-path", default=str(base_dir / "results" / "disc_vocab.json"))
    parser.add_argument("--model-path", default=str(base_dir / "results" / "model.pt"))
    parser.add_argument("--out", default=str(base_dir / "results" / "generated.csv"))
    parser.add_argument("--timesteps", type=int, default=200)
    parser.add_argument("--seq-len", type=int, default=64)
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
    parser.add_argument("--include-time", action="store_true", help="Include time column as a simple index")
    parser.add_argument("--clip-k", type=float, default=5.0, help="Clip continuous values to mean±k*std")
    parser.add_argument("--use-ema", action="store_true", help="Use EMA weights if available")
    parser.add_argument("--config", default=None, help="Optional config_used.json to infer conditioning")
    parser.add_argument("--condition-id", type=int, default=-1, help="Condition file id (0..N-1), -1=random")
    parser.add_argument("--include-condition", action="store_true", help="Include condition id column in CSV")
    return parser.parse_args()


# Device resolution is delegated to resolve_device from platform_utils.
def main():
    """Sample continuous + discrete feature sequences and export them as CSV."""
    args = parse_args()
    base_dir = Path(__file__).resolve().parent
    # Resolve every user-supplied path relative to this script's directory.
    args.data_path = str(resolve_path(base_dir, args.data_path))
    args.data_glob = str(resolve_path(base_dir, args.data_glob)) if args.data_glob else ""
    args.split_path = str(resolve_path(base_dir, args.split_path))
    args.stats_path = str(resolve_path(base_dir, args.stats_path))
    args.vocab_path = str(resolve_path(base_dir, args.vocab_path))
    args.model_path = str(resolve_path(base_dir, args.model_path))
    args.out = str(resolve_path(base_dir, args.out))
    if not os.path.exists(args.model_path):
        raise SystemExit("missing model file: %s" % args.model_path)

    # Resolve the header source: prefer the first file matching --data-glob.
    data_path = args.data_path
    if args.data_glob:
        base = Path(args.data_glob).parent
        pat = Path(args.data_glob).name
        matches = sorted(base.glob(pat))
        if matches:
            data_path = str(matches[0])

    split = load_split(args.split_path)
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    # Attack-label columns are excluded from the discrete features.
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]

    stats = load_stats(args.stats_path)
    mean = stats["mean"]
    std = stats["std"]
    vmin = stats.get("min", {})
    vmax = stats.get("max", {})
    int_like = stats.get("int_like", {})
    max_decimals = stats.get("max_decimals", {})
    transforms = stats.get("transform", {})

    # FIX: the original used json.load(open(...)) and leaked the file handle.
    with open(args.vocab_path, "r", encoding="utf-8") as f:
        vocab_json = json.load(f)
    vocab = vocab_json["vocab"]
    top_token = vocab_json.get("top_token", {})
    inv_vocab = build_inverse_vocab(vocab)
    vocab_sizes = [len(vocab[c]) for c in disc_cols]

    device = resolve_device(args.device)

    # Optionally read the training config to infer conditioning and the
    # continuous prediction target used at train time.
    cfg = {}
    use_condition = False
    cond_vocab_size = 0
    if args.config:
        args.config = str(resolve_path(base_dir, args.config))
    if args.config and os.path.exists(args.config):
        with open(args.config, "r", encoding="utf-8") as f:
            cfg = json.load(f)
        use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id"
        if use_condition:
            # Count the files matching the training data_glob: the number of
            # matches defines the condition-id vocabulary size.
            cfg_base = Path(args.config).resolve().parent
            cfg_glob = cfg.get("data_glob", args.data_glob)
            cfg_glob = str(resolve_path(cfg_base, cfg_glob))
            base = Path(cfg_glob).parent
            pat = Path(cfg_glob).name
            cond_vocab_size = len(sorted(base.glob(pat)))
            if cond_vocab_size <= 0:
                raise SystemExit("use_condition enabled but no files matched data_glob: %s" % cfg_glob)
    cont_target = str(cfg.get("cont_target", "eps"))
    cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))

    model = HybridDiffusionModel(
        cont_dim=len(cont_cols),
        disc_vocab_sizes=vocab_sizes,
        time_dim=int(cfg.get("model_time_dim", 64)),
        hidden_dim=int(cfg.get("model_hidden_dim", 256)),
        num_layers=int(cfg.get("model_num_layers", 1)),
        dropout=float(cfg.get("model_dropout", 0.0)),
        ff_mult=int(cfg.get("model_ff_mult", 2)),
        pos_dim=int(cfg.get("model_pos_dim", 64)),
        use_pos_embed=bool(cfg.get("model_use_pos_embed", True)),
        cond_vocab_size=cond_vocab_size if use_condition else 0,
        cond_dim=int(cfg.get("cond_dim", 32)),
        use_tanh_eps=bool(cfg.get("use_tanh_eps", False)),
        eps_scale=float(cfg.get("eps_scale", 1.0)),
    ).to(device)
    if args.use_ema and os.path.exists(args.model_path.replace("model.pt", "model_ema.pt")):
        ema_path = args.model_path.replace("model.pt", "model_ema.pt")
        model.load_state_dict(torch.load(ema_path, map_location=device, weights_only=True))
    else:
        model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True))
    model.eval()

    betas = cosine_beta_schedule(args.timesteps).to(device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    # Continuous channel starts from pure Gaussian noise; the discrete channel
    # starts fully masked (mask token == vocab size for each column).
    x_cont = torch.randn(args.batch_size, args.seq_len, len(cont_cols), device=device)
    x_disc = torch.full(
        (args.batch_size, args.seq_len, len(disc_cols)),
        0,
        device=device,
        dtype=torch.long,
    )
    mask_tokens = torch.tensor(vocab_sizes, device=device)
    for i in range(len(disc_cols)):
        x_disc[:, :, i] = mask_tokens[i]

    # Per-sample condition id (file id), random unless pinned by --condition-id.
    cond = None
    if use_condition:
        if cond_vocab_size <= 0:
            raise SystemExit("use_condition enabled but no files matched data_glob")
        if args.condition_id < 0:
            cond_id = torch.randint(0, cond_vocab_size, (args.batch_size,), device=device)
        else:
            cond_id = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long)
        cond = cond_id

    # FIX: the original ran the reverse-diffusion loop with autograd enabled,
    # retaining a graph across all timesteps for no reason in a sampler.
    with torch.no_grad():
        for t in reversed(range(args.timesteps)):
            t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
            eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
            a_t = alphas[t]
            a_bar_t = alphas_cumprod[t]
            # Convert the model output to an epsilon prediction regardless of
            # the training target (eps / x0 / v parameterization).
            if cont_target == "x0":
                x0_pred = eps_pred
                if cont_clamp_x0 > 0:
                    x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0)
                eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)
            elif cont_target == "v":
                v_pred = eps_pred
                x0_pred = torch.sqrt(a_bar_t) * x_cont - torch.sqrt(1.0 - a_bar_t) * v_pred
                eps_pred = torch.sqrt(1.0 - a_bar_t) * x_cont + torch.sqrt(a_bar_t) * v_pred
            # Standard DDPM posterior mean; add noise except at the final step.
            coef1 = 1.0 / torch.sqrt(a_t)
            coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
            mean_x = coef1 * (x_cont - coef2 * eps_pred)
            if t > 0:
                noise = torch.randn_like(x_cont)
                x_cont = mean_x + torch.sqrt(betas[t]) * noise
            else:
                x_cont = mean_x
            if args.clip_k > 0:
                x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k)
            # Discrete channel: sample still-masked positions from the model's
            # categorical distribution; at t == 0 force a deterministic argmax.
            for i, logit in enumerate(logits):
                if t == 0:
                    # argmax over logits == argmax over softmax(logits); the
                    # original's extra softmax was redundant.
                    x_disc[:, :, i] = torch.argmax(logit, dim=-1)
                else:
                    mask = x_disc[:, :, i] == mask_tokens[i]
                    if mask.any():
                        probs = F.softmax(logit, dim=-1)
                        sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(
                            args.batch_size, args.seq_len
                        )
                        x_disc[:, :, i][mask] = sampled[mask]

    # Move to CPU for export.
    x_cont = x_cont.cpu()
    x_disc = x_disc.cpu()
    # Clip in normalized space to avoid extreme blow-up.
    if args.clip_k > 0:
        x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k)
    # De-normalize, undo the log1p transform where it was applied, and clamp
    # to the observed per-feature min/max.
    mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
    std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
    x_cont = x_cont * std_vec + mean_vec
    for i, c in enumerate(cont_cols):
        if transforms.get(c) == "log1p":
            x_cont[:, :, i] = torch.expm1(x_cont[:, :, i])
    if vmin and vmax:
        for i, c in enumerate(cont_cols):
            lo = vmin.get(c, None)
            hi = vmax.get(c, None)
            if lo is not None and hi is not None:
                x_cont[:, :, i] = torch.clamp(x_cont[:, :, i], float(lo), float(hi))

    # Build the output column list from the real dataset header so the CSV is
    # schema-compatible; columns we did not generate are left empty.
    header = read_header(data_path)
    out_cols = [c for c in header if c != time_col or args.include_time]
    if args.include_condition and use_condition:
        out_cols = ["__cond_file_id"] + out_cols
    # FIX: guard against an empty dirname, which would make makedirs raise.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(args.out, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=out_cols)
        writer.writeheader()
        row_index = 0
        for b in range(args.batch_size):
            for t in range(args.seq_len):
                row = {}
                if args.include_condition and use_condition:
                    row["__cond_file_id"] = str(int(cond[b].item())) if cond is not None else "-1"
                if args.include_time and time_col in header:
                    row[time_col] = str(row_index)
                for i, c in enumerate(cont_cols):
                    val = float(x_cont[b, t, i])
                    if int_like.get(c, False):
                        row[c] = str(int(round(val)))
                    else:
                        # Honor the per-feature decimal precision seen in the
                        # training data.
                        dec = int(max_decimals.get(c, 6))
                        fmt = ("%%.%df" % dec) if dec > 0 else "%.0f"
                        row[c] = (fmt % val)
                for i, c in enumerate(disc_cols):
                    tok_idx = int(x_disc[b, t, i])
                    tok = inv_vocab[c][tok_idx] if tok_idx < len(inv_vocab[c]) else ""
                    # Fall back to the most frequent token when the index maps
                    # to nothing (e.g. a lingering mask token).
                    if tok == "" and c in top_token:
                        tok = top_token[c]
                    row[c] = tok
                writer.writerow(row)
                row_index += 1
    print("exported_csv", args.out)
    print("rows", args.batch_size * args.seq_len)


if __name__ == "__main__":
    main()