#!/usr/bin/env python3
"""Sample from a trained hybrid diffusion model and export to CSV."""
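#
# Example invocation (a sketch: the script filename is an assumption, and the default
# paths below point at the HAI 21.03 training CSVs; adjust both to your checkout):
#   python sample_and_export.py --timesteps 200 --seq-len 64 --batch-size 2 --use-ema
#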
import argparse
import csv
import gzip
import json
import os
from pathlib import Path
from typing import Dict, List

import torch
import torch.nn.functional as F

from data_utils import load_split
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path


def load_vocab(path: str) -> Dict[str, Dict[str, int]]:
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)["vocab"]


def load_stats(path: str):
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def read_header(path: str) -> List[str]:
    if path.endswith(".gz"):
        opener = gzip.open
        mode = "rt"
    else:
        opener = open
        mode = "r"
    with opener(path, mode, newline="") as f:
        reader = csv.reader(f)
        return next(reader)


def build_inverse_vocab(vocab: Dict[str, Dict[str, int]]) -> Dict[str, List[str]]:
    inv = {}
    for col, mapping in vocab.items():
        inverse = [""] * len(mapping)
        for tok, idx in mapping.items():
            inverse[idx] = tok
        inv[col] = inverse
    return inv

def parse_args():
    parser = argparse.ArgumentParser(description="Sample and export HAI feature sequences.")
    base_dir = Path(__file__).resolve().parent
    repo_dir = base_dir.parent.parent
    parser.add_argument("--data-path", default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz"))
    parser.add_argument("--data-glob", default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz"))
    parser.add_argument("--split-path", default=str(base_dir / "feature_split.json"))
    parser.add_argument("--stats-path", default=str(base_dir / "results" / "cont_stats.json"))
    parser.add_argument("--vocab-path", default=str(base_dir / "results" / "disc_vocab.json"))
    parser.add_argument("--model-path", default=str(base_dir / "results" / "model.pt"))
    parser.add_argument("--out", default=str(base_dir / "results" / "generated.csv"))
    parser.add_argument("--timesteps", type=int, default=200)
    parser.add_argument("--seq-len", type=int, default=64)
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
    parser.add_argument("--include-time", action="store_true", help="Include time column as a simple index")
    parser.add_argument("--clip-k", type=float, default=5.0, help="Clip continuous values to mean±k*std")
    parser.add_argument("--use-ema", action="store_true", help="Use EMA weights if available")
    parser.add_argument("--config", default=None, help="Optional config_used.json to infer conditioning")
    parser.add_argument("--condition-id", type=int, default=-1, help="Condition file id (0..N-1), -1=random")
    parser.add_argument("--include-condition", action="store_true", help="Include condition id column in CSV")
    return parser.parse_args()


# Device selection is handled by resolve_device from platform_utils.


def main():
    args = parse_args()
    base_dir = Path(__file__).resolve().parent
    args.data_path = str(resolve_path(base_dir, args.data_path))
    args.data_glob = str(resolve_path(base_dir, args.data_glob)) if args.data_glob else ""
    args.split_path = str(resolve_path(base_dir, args.split_path))
    args.stats_path = str(resolve_path(base_dir, args.stats_path))
    args.vocab_path = str(resolve_path(base_dir, args.vocab_path))
    args.model_path = str(resolve_path(base_dir, args.model_path))
    args.out = str(resolve_path(base_dir, args.out))

    if not os.path.exists(args.model_path):
        raise SystemExit("missing model file: %s" % args.model_path)

    # Resolve the header source: prefer the first file matched by --data-glob.
    data_path = args.data_path
    if args.data_glob:
        base = Path(args.data_glob).parent
        pat = Path(args.data_glob).name
        matches = sorted(base.glob(pat))
        if matches:
            data_path = str(matches[0])

    split = load_split(args.split_path)
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]

    stats = load_stats(args.stats_path)
    mean = stats["mean"]
    std = stats["std"]
    vmin = stats.get("min", {})
    vmax = stats.get("max", {})
    int_like = stats.get("int_like", {})
    max_decimals = stats.get("max_decimals", {})
    transforms = stats.get("transform", {})

    with open(args.vocab_path, "r", encoding="utf-8") as f:
        vocab_json = json.load(f)
    vocab = vocab_json["vocab"]
    top_token = vocab_json.get("top_token", {})
    inv_vocab = build_inverse_vocab(vocab)
    vocab_sizes = [len(vocab[c]) for c in disc_cols]

    device = resolve_device(args.device)
    cfg = {}
    use_condition = False
    cond_vocab_size = 0
    if args.config:
        args.config = str(resolve_path(base_dir, args.config))
    if args.config and os.path.exists(args.config):
        with open(args.config, "r", encoding="utf-8") as f:
            cfg = json.load(f)
        use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id"
        if use_condition:
            cfg_base = Path(args.config).resolve().parent
            cfg_glob = cfg.get("data_glob", args.data_glob)
            cfg_glob = str(resolve_path(cfg_base, cfg_glob))
            base = Path(cfg_glob).parent
            pat = Path(cfg_glob).name
            cond_vocab_size = len(sorted(base.glob(pat)))
            if cond_vocab_size <= 0:
                raise SystemExit("use_condition enabled but no files matched data_glob: %s" % cfg_glob)
    cont_target = str(cfg.get("cont_target", "eps"))
    cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
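    # The architecture hyperparameters are read from config_used.json when --config is
    # given; otherwise the cfg.get(...) defaults below are used and must match the
    # values used at training time, or load_state_dict will fail.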
    model = HybridDiffusionModel(
        cont_dim=len(cont_cols),
        disc_vocab_sizes=vocab_sizes,
        time_dim=int(cfg.get("model_time_dim", 64)),
        hidden_dim=int(cfg.get("model_hidden_dim", 256)),
        num_layers=int(cfg.get("model_num_layers", 1)),
        dropout=float(cfg.get("model_dropout", 0.0)),
        ff_mult=int(cfg.get("model_ff_mult", 2)),
        pos_dim=int(cfg.get("model_pos_dim", 64)),
        use_pos_embed=bool(cfg.get("model_use_pos_embed", True)),
        use_feature_graph=bool(cfg.get("model_use_feature_graph", False)),
        feature_graph_scale=float(cfg.get("feature_graph_scale", 0.1)),
        feature_graph_dropout=float(cfg.get("feature_graph_dropout", 0.0)),
        cond_vocab_size=cond_vocab_size if use_condition else 0,
        cond_dim=int(cfg.get("cond_dim", 32)),
        use_tanh_eps=bool(cfg.get("use_tanh_eps", False)),
        eps_scale=float(cfg.get("eps_scale", 1.0)),
    ).to(device)
    # Prefer EMA weights when requested; silently fall back to the raw checkpoint.
    ema_path = args.model_path.replace("model.pt", "model_ema.pt")
    if args.use_ema and os.path.exists(ema_path):
        model.load_state_dict(torch.load(ema_path, map_location=device, weights_only=True))
    else:
        model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True))
    model.eval()

    betas = cosine_beta_schedule(args.timesteps).to(device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
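    # Start the reverse process from pure noise: continuous channels are standard
    # Gaussian in normalized space, and every discrete channel is filled with its
    # mask token (index == vocabulary size, i.e. one past the last real token).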
    x_cont = torch.randn(args.batch_size, args.seq_len, len(cont_cols), device=device)
    x_disc = torch.full(
        (args.batch_size, args.seq_len, len(disc_cols)),
        0,
        device=device,
        dtype=torch.long,
    )
    mask_tokens = torch.tensor(vocab_sizes, device=device)
    for i in range(len(disc_cols)):
        x_disc[:, :, i] = mask_tokens[i]
    # Condition id: with file_id conditioning, each sequence in the batch is tied to
    # one training file (0..cond_vocab_size-1); --condition-id -1 draws a random id.
    cond = None
    if use_condition:
        if cond_vocab_size <= 0:
            raise SystemExit("use_condition enabled but no files matched data_glob")
        if args.condition_id < 0:
            cond_id = torch.randint(0, cond_vocab_size, (args.batch_size,), device=device)
        else:
            cond_id = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long)
        cond = cond_id
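    # Reverse diffusion: walk t = T-1 .. 0, denoising the continuous block with a
    # DDPM-style ancestral update and unmasking the discrete block from the model's
    # per-column categorical logits.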
    for t in reversed(range(args.timesteps)):
        t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
        # Sampling needs no gradients; no_grad keeps the loop from building a graph.
        with torch.no_grad():
            eps_pred, logits = model(x_cont, x_disc, t_batch, cond)

        a_t = alphas[t]
        a_bar_t = alphas_cumprod[t]
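        # The network's continuous head is interpreted according to cont_target:
        # "eps" predicts the noise directly; "x0" predicts the clean signal, from
        # which eps is recovered via x_t = sqrt(a_bar)*x0 + sqrt(1-a_bar)*eps;
        # "v" uses the v-parameterization v = sqrt(a_bar)*eps - sqrt(1-a_bar)*x0.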
if cont_target == "x0":
|
|
x0_pred = eps_pred
|
|
if cont_clamp_x0 > 0:
|
|
x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0)
|
|
eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)
|
|
elif cont_target == "v":
|
|
v_pred = eps_pred
|
|
x0_pred = torch.sqrt(a_bar_t) * x_cont - torch.sqrt(1.0 - a_bar_t) * v_pred
|
|
eps_pred = torch.sqrt(1.0 - a_bar_t) * x_cont + torch.sqrt(a_bar_t) * v_pred
|
|
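        # DDPM posterior mean: mu_t = (x_t - (1 - a_t) / sqrt(1 - a_bar_t) * eps) / sqrt(a_t),
        # with fresh noise scaled by sqrt(beta_t) added for every step except the last.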
        coef1 = 1.0 / torch.sqrt(a_t)
        coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
        mean_x = coef1 * (x_cont - coef2 * eps_pred)
        if t > 0:
            noise = torch.randn_like(x_cont)
            x_cont = mean_x + torch.sqrt(betas[t]) * noise
        else:
            x_cont = mean_x
        if args.clip_k > 0:
            x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k)
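        # Discrete columns: while t > 0, positions still holding the mask token are
        # filled by sampling from the softmax over that column's logits; at t == 0
        # every position is finalized with the argmax token.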
        for i, logit in enumerate(logits):
            if t == 0:
                probs = F.softmax(logit, dim=-1)
                x_disc[:, :, i] = torch.argmax(probs, dim=-1)
            else:
                mask = x_disc[:, :, i] == mask_tokens[i]
                if mask.any():
                    probs = F.softmax(logit, dim=-1)
                    sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(
                        args.batch_size, args.seq_len
                    )
                    x_disc[:, :, i][mask] = sampled[mask]
    # move to CPU for export
    x_cont = x_cont.cpu()
    x_disc = x_disc.cpu()

    # clip in normalized space to avoid extreme blow-up
    if args.clip_k > 0:
        x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k)
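    # Undo the training-time normalization: samples live in z-score space, so scale by
    # std and shift by mean, invert the optional log1p transform, then clamp each
    # feature to the min/max observed in the training data.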
    mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
    std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
    x_cont = x_cont * std_vec + mean_vec
    for i, c in enumerate(cont_cols):
        if transforms.get(c) == "log1p":
            x_cont[:, :, i] = torch.expm1(x_cont[:, :, i])
    # clamp to observed min/max per feature
    if vmin and vmax:
        for i, c in enumerate(cont_cols):
            lo = vmin.get(c, None)
            hi = vmax.get(c, None)
            if lo is not None and hi is not None:
                x_cont[:, :, i] = torch.clamp(x_cont[:, :, i], float(lo), float(hi))
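    # CSV layout follows the original data header. Header columns that receive no value
    # (the attack* columns are dropped from the discrete set above) are left blank by
    # DictWriter; the time column is an optional running index, int-like features are
    # rounded, and the rest keep the per-feature decimal precision from the stats file.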
    header = read_header(data_path)
    out_cols = [c for c in header if c != time_col or args.include_time]
    if args.include_condition and use_condition:
        out_cols = ["__cond_file_id"] + out_cols

    os.makedirs(os.path.dirname(args.out), exist_ok=True)
    with open(args.out, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=out_cols)
        writer.writeheader()

        row_index = 0
        for b in range(args.batch_size):
            for t in range(args.seq_len):
                row = {}
                if args.include_condition and use_condition:
                    row["__cond_file_id"] = str(int(cond[b].item())) if cond is not None else "-1"
                if args.include_time and time_col in header:
                    row[time_col] = str(row_index)
                for i, c in enumerate(cont_cols):
                    val = float(x_cont[b, t, i])
                    if int_like.get(c, False):
                        row[c] = str(int(round(val)))
                    else:
                        dec = int(max_decimals.get(c, 6))
                        fmt = ("%%.%df" % dec) if dec > 0 else "%.0f"
                        row[c] = fmt % val
                for i, c in enumerate(disc_cols):
                    tok_idx = int(x_disc[b, t, i])
                    tok = inv_vocab[c][tok_idx] if tok_idx < len(inv_vocab[c]) else "<UNK>"
                    if tok == "<UNK>" and c in top_token:
                        tok = top_token[c]
                    row[c] = tok
                writer.writerow(row)
                row_index += 1

    print("exported_csv", args.out)
    print("rows", args.batch_size * args.seq_len)


if __name__ == "__main__":
    main()