Files
mask-ddpm/example/export_samples_resume.py
2026-04-18 19:01:25 +08:00

441 lines
20 KiB
Python

#!/usr/bin/env python3
"""Sample from a trained hybrid diffusion model with routing-aware export fixes."""
from __future__ import annotations
import argparse
import csv
import gzip
import json
import os
from pathlib import Path
import torch
import torch.nn.functional as F
from data_utils import inverse_quantile_transform, load_split, normalize_cont, quantile_calibrate_to_real
from export_samples import build_inverse_vocab, load_stats, load_torch_state, read_header
from hybrid_diffusion import (
HybridDiffusionModel,
TemporalGRUGenerator,
TemporalTransformerGenerator,
cosine_beta_schedule,
)
from platform_utils import resolve_device, resolve_path
from submission_type_utils import (
denormalize_cont_tensor,
resolve_routing_features,
resolve_taxonomy_features,
)
def parse_args():
parser = argparse.ArgumentParser(description="Sample and export HAI feature sequences with routing-aware fixes.")
base_dir = Path(__file__).resolve().parent
repo_dir = base_dir.parent.parent
parser.add_argument("--data-path", default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz"))
parser.add_argument("--data-glob", default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz"))
parser.add_argument("--split-path", default=str(base_dir / "feature_split.json"))
parser.add_argument("--stats-path", default=str(base_dir / "results" / "cont_stats.json"))
parser.add_argument("--vocab-path", default=str(base_dir / "results" / "disc_vocab.json"))
parser.add_argument("--model-path", default=str(base_dir / "results" / "model.pt"))
parser.add_argument("--out", default=str(base_dir / "results" / "generated.csv"))
parser.add_argument("--timesteps", type=int, default=200)
parser.add_argument("--seq-len", type=int, default=64)
parser.add_argument("--batch-size", type=int, default=2)
parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
parser.add_argument("--include-time", action="store_true", help="Include time column as a simple index")
parser.add_argument("--clip-k", type=float, default=5.0, help="Clip continuous values to mean±k*std")
parser.add_argument("--use-ema", action="store_true", help="Use EMA weights if available")
parser.add_argument("--config", default=None, help="Optional config_used.json to infer conditioning")
parser.add_argument("--condition-id", type=int, default=-1, help="Condition file id (0..N-1), -1=random")
parser.add_argument("--include-condition", action="store_true", help="Include condition id column in CSV")
return parser.parse_args()
def main() -> None:
    """Load a trained hybrid diffusion model and export generated sequences as CSV.

    Pipeline: resolve artifact paths -> load feature split / continuous stats /
    discrete vocab -> rebuild the model (optionally conditioned on a file id
    and on routed "type1" continuous features) -> run reverse diffusion,
    optionally adding a stage-1 temporal trend -> denormalize and bound the
    continuous outputs -> write one CSV row per generated timestep.
    """
    args = parse_args()
    base_dir = Path(__file__).resolve().parent
    # Resolve every user-supplied path relative to this script's directory.
    args.data_path = str(resolve_path(base_dir, args.data_path))
    args.data_glob = str(resolve_path(base_dir, args.data_glob)) if args.data_glob else ""
    args.split_path = str(resolve_path(base_dir, args.split_path))
    args.stats_path = str(resolve_path(base_dir, args.stats_path))
    args.vocab_path = str(resolve_path(base_dir, args.vocab_path))
    args.model_path = str(resolve_path(base_dir, args.model_path))
    args.out = str(resolve_path(base_dir, args.out))
    if not os.path.exists(args.model_path):
        raise SystemExit(f"missing model file: {args.model_path}")
    # Prefer the first file matched by --data-glob; only used to read the CSV
    # header for the export column order.
    data_path = args.data_path
    if args.data_glob:
        base = Path(args.data_glob).parent
        pat = Path(args.data_glob).name
        matches = sorted(base.glob(pat))
        if matches:
            data_path = str(matches[0])
    # Feature split: continuous vs. discrete columns; attack-label columns and
    # the time column are excluded from generation.
    split = load_split(args.split_path)
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
    # Per-column normalization statistics saved at training time.
    stats = load_stats(args.stats_path)
    mean = stats["mean"]
    std = stats["std"]
    vmin = stats.get("min", {})
    vmax = stats.get("max", {})
    int_like = stats.get("int_like", {})
    max_decimals = stats.get("max_decimals", {})
    transforms = stats.get("transform", {})
    quantile_probs = stats.get("quantile_probs")
    quantile_values = stats.get("quantile_values")
    quantile_raw_values = stats.get("quantile_raw_values")
    vocab_json = json.load(open(args.vocab_path, "r", encoding="utf-8"))
    vocab = vocab_json["vocab"]
    top_token = vocab_json.get("top_token", {})
    inv_vocab = build_inverse_vocab(vocab)
    vocab_sizes = [len(vocab[c]) for c in disc_cols]
    device = resolve_device(args.device)
    # Optional training config: enables file-id conditioning, routing and
    # the stage-1 temporal generator. All reads below use cfg.get() defaults,
    # so a missing config leaves everything disabled.
    cfg = {}
    use_condition = False
    cond_vocab_size = 0
    if args.config:
        args.config = str(resolve_path(base_dir, args.config))
    if args.config and os.path.exists(args.config):
        with open(args.config, "r", encoding="utf-8") as f:
            cfg = json.load(f)
        use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id"
        if use_condition:
            # Condition vocabulary size = number of training files matching the glob.
            cfg_base = Path(args.config).resolve().parent
            cfg_glob = cfg.get("data_glob", args.data_glob)
            cfg_glob = str(resolve_path(cfg_base, cfg_glob))
            base = Path(cfg_glob).parent
            pat = Path(cfg_glob).name
            cond_vocab_size = len(sorted(base.glob(pat)))
            if cond_vocab_size <= 0:
                raise SystemExit(f"use_condition enabled but no files matched data_glob: {cfg_glob}")
    # Continuous-branch settings: what the model predicts ("eps" or "x0"),
    # clamping, quantile transform, and post-hoc bounding/scaling.
    cont_target = str(cfg.get("cont_target", "eps"))
    cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
    use_quantile = bool(cfg.get("use_quantile_transform", False))
    cont_bound_mode = str(cfg.get("cont_bound_mode", "clamp"))
    cont_bound_strength = float(cfg.get("cont_bound_strength", 1.0))
    cont_post_scale = cfg.get("cont_post_scale", {}) if isinstance(cfg.get("cont_post_scale", {}), dict) else {}
    cont_post_calibrate = bool(cfg.get("cont_post_calibrate", False))
    # Routing: type1 features become model conditioning inputs, type5 features
    # are copied from a sibling column at export time; both are removed from
    # the set of columns the diffusion model generates.
    route_type1_cols = resolve_routing_features(cfg, cont_cols, "type1_features")
    route_type5_cols = resolve_routing_features(cfg, cont_cols, "type5_features")
    type4_cols = resolve_taxonomy_features(cfg, cont_cols, "type4_features")
    model_cont_cols = [c for c in cont_cols if c not in route_type1_cols and c not in route_type5_cols]
    # Stage-1 temporal generator settings (trend added on top of diffusion output).
    use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
    temporal_use_type1_cond = bool(cfg.get("temporal_use_type1_cond", False))
    temporal_focus_type4 = bool(cfg.get("temporal_focus_type4", False))
    temporal_exclude_type4 = bool(cfg.get("temporal_exclude_type4", False))
    temporal_backbone = str(cfg.get("temporal_backbone", "gru"))
    temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
    temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
    temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
    temporal_pos_dim = int(cfg.get("temporal_pos_dim", 64))
    temporal_use_pos_embed = bool(cfg.get("temporal_use_pos_embed", True))
    temporal_transformer_num_layers = int(cfg.get("temporal_transformer_num_layers", 2))
    temporal_transformer_nhead = int(cfg.get("temporal_transformer_nhead", 4))
    temporal_transformer_ff_dim = int(cfg.get("temporal_transformer_ff_dim", 512))
    temporal_transformer_dropout = float(cfg.get("temporal_transformer_dropout", 0.1))
    backbone_type = str(cfg.get("backbone_type", "gru"))
    transformer_num_layers = int(cfg.get("transformer_num_layers", 2))
    transformer_nhead = int(cfg.get("transformer_nhead", 4))
    transformer_ff_dim = int(cfg.get("transformer_ff_dim", 512))
    transformer_dropout = float(cfg.get("transformer_dropout", 0.1))
    # Rebuild the model with the same hyperparameters used at training time.
    model = HybridDiffusionModel(
        cont_dim=len(model_cont_cols),
        disc_vocab_sizes=vocab_sizes,
        time_dim=int(cfg.get("model_time_dim", 64)),
        hidden_dim=int(cfg.get("model_hidden_dim", 256)),
        num_layers=int(cfg.get("model_num_layers", 1)),
        dropout=float(cfg.get("model_dropout", 0.0)),
        ff_mult=int(cfg.get("model_ff_mult", 2)),
        pos_dim=int(cfg.get("model_pos_dim", 64)),
        use_pos_embed=bool(cfg.get("model_use_pos_embed", True)),
        backbone_type=backbone_type,
        transformer_num_layers=transformer_num_layers,
        transformer_nhead=transformer_nhead,
        transformer_ff_dim=transformer_ff_dim,
        transformer_dropout=transformer_dropout,
        cond_cont_dim=len(route_type1_cols),
        cond_vocab_size=cond_vocab_size if use_condition else 0,
        cond_dim=int(cfg.get("cond_dim", 32)),
        use_tanh_eps=bool(cfg.get("use_tanh_eps", False)),
        eps_scale=float(cfg.get("eps_scale", 1.0)),
    ).to(device)
    # Prefer EMA weights (model_ema.pt next to model.pt) when requested and present.
    if args.use_ema and os.path.exists(args.model_path.replace("model.pt", "model_ema.pt")):
        ema_path = args.model_path.replace("model.pt", "model_ema.pt")
        model.load_state_dict(load_torch_state(ema_path, device))
    else:
        model.load_state_dict(load_torch_state(args.model_path, device))
    model.eval()
    temporal_model = None
    if use_temporal_stage1:
        temporal_path = Path(args.model_path).with_name("temporal.pt")
        if not temporal_path.exists():
            raise SystemExit(f"missing temporal model file: {temporal_path}")
        temporal_state = load_torch_state(str(temporal_path), device)
        temporal_cond_dim = len(route_type1_cols) if (temporal_use_type1_cond and route_type1_cols) else 0
        # Infer the checkpoint's conditioning width from its first-layer input
        # shape, overriding the config-derived value when possible (handles
        # resuming with a mismatched config). Best-effort: failures keep the
        # config-derived dimension.
        if isinstance(temporal_state, dict):
            if "in_proj.weight" in temporal_state:
                try:
                    temporal_cond_dim = max(0, int(temporal_state["in_proj.weight"].shape[1]) - len(model_cont_cols))
                except Exception:
                    pass
            elif "gru.weight_ih_l0" in temporal_state:
                try:
                    temporal_cond_dim = max(0, int(temporal_state["gru.weight_ih_l0"].shape[1]) - len(model_cont_cols))
                except Exception:
                    pass
        if temporal_backbone == "transformer":
            temporal_model = TemporalTransformerGenerator(
                input_dim=len(model_cont_cols),
                hidden_dim=temporal_hidden_dim,
                num_layers=temporal_transformer_num_layers,
                nhead=temporal_transformer_nhead,
                ff_dim=temporal_transformer_ff_dim,
                dropout=temporal_transformer_dropout,
                pos_dim=temporal_pos_dim,
                use_pos_embed=temporal_use_pos_embed,
                cond_dim=temporal_cond_dim,
            ).to(device)
        else:
            temporal_model = TemporalGRUGenerator(
                input_dim=len(model_cont_cols),
                hidden_dim=temporal_hidden_dim,
                num_layers=temporal_num_layers,
                dropout=temporal_dropout,
                cond_dim=temporal_cond_dim,
            ).to(device)
        temporal_model.load_state_dict(temporal_state)
        temporal_model.eval()
    # Cosine noise schedule for ancestral DDPM sampling.
    betas = cosine_beta_schedule(args.timesteps).to(device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    # Continuous state starts as Gaussian noise; discrete state starts with
    # every position set to the MASK token (index == vocab size per column).
    x_cont = torch.randn(args.batch_size, args.seq_len, len(model_cont_cols), device=device)
    x_disc = torch.full((args.batch_size, args.seq_len, len(disc_cols)), 0, device=device, dtype=torch.long)
    mask_tokens = torch.tensor(vocab_sizes, device=device)
    for i in range(len(disc_cols)):
        x_disc[:, :, i] = mask_tokens[i]
    # File-id condition: random per sample unless --condition-id is given.
    cond = None
    if use_condition:
        if cond_vocab_size <= 0:
            raise SystemExit("use_condition enabled but no files matched data_glob")
        if args.condition_id < 0:
            cond_id = torch.randint(0, cond_vocab_size, (args.batch_size,), device=device)
        else:
            cond_id = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long)
        cond = cond_id
    # Type1 conditioning: take the first seq_len rows of a real reference file,
    # normalize them, and broadcast the same sequence across the batch.
    cond_cont = None
    if route_type1_cols:
        ref_glob = cfg.get("data_glob") or args.data_glob
        if ref_glob:
            ref_glob = str(resolve_path(Path(args.config).parent, ref_glob)) if args.config else ref_glob
            base = Path(ref_glob).parent
            pat = Path(ref_glob).name
            refs = sorted(base.glob(pat))
            if refs:
                ref_path = refs[0]
                ref_rows = []
                # NOTE(review): assumes the reference file is gzip-compressed
                # CSV (matches the train*.csv.gz defaults) — confirm for other globs.
                with gzip.open(ref_path, "rt", newline="") as fh:
                    reader = csv.DictReader(fh)
                    for row in reader:
                        ref_rows.append(row)
                if len(ref_rows) >= args.seq_len:
                    seq = ref_rows[: args.seq_len]
                    cond_cont = torch.zeros(args.batch_size, args.seq_len, len(route_type1_cols), device=device)
                    for t, row in enumerate(seq):
                        for i, c in enumerate(route_type1_cols):
                            cond_cont[:, t, i] = float(row[c])
                    cond_cont = normalize_cont(
                        cond_cont,
                        route_type1_cols,
                        mean,
                        std,
                        transforms=transforms,
                        quantile_probs=quantile_probs,
                        quantile_values=quantile_values,
                        use_quantile=use_quantile,
                    )
    # Stage-1 trend, optionally restricted to (or excluding) type4 columns
    # via a per-feature 0/1 mask.
    trend = None
    if temporal_model is not None:
        trend = temporal_model.generate(args.batch_size, args.seq_len, device, cond_cont=cond_cont)
        if temporal_focus_type4 and type4_cols:
            type4_model_idx = [model_cont_cols.index(c) for c in type4_cols if c in model_cont_cols]
            if type4_model_idx:
                trend_mask = torch.zeros(1, 1, len(model_cont_cols), device=device, dtype=trend.dtype)
                trend_mask[:, :, type4_model_idx] = 1.0
                trend = trend * trend_mask
        elif temporal_exclude_type4 and type4_cols:
            type4_model_idx = [model_cont_cols.index(c) for c in type4_cols if c in model_cont_cols]
            if type4_model_idx:
                trend_mask = torch.ones(1, 1, len(model_cont_cols), device=device, dtype=trend.dtype)
                trend_mask[:, :, type4_model_idx] = 0.0
                trend = trend * trend_mask
    # Reverse diffusion: ancestral DDPM updates for the continuous branch,
    # mask-token filling for the discrete branch.
    for t in reversed(range(args.timesteps)):
        t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
        eps_pred, logits = model(x_cont, x_disc, t_batch, cond, cond_cont=cond_cont)
        a_t = alphas[t]
        a_bar_t = alphas_cumprod[t]
        if cont_target == "x0":
            # Model predicts x0 directly; convert to an equivalent eps
            # prediction so the same posterior-mean formula applies.
            x0_pred = eps_pred
            if cont_clamp_x0 > 0:
                x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0)
            eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)
        coef1 = 1.0 / torch.sqrt(a_t)
        coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
        mean_x = coef1 * (x_cont - coef2 * eps_pred)
        if t > 0:
            noise = torch.randn_like(x_cont)
            x_cont = mean_x + torch.sqrt(betas[t]) * noise
        else:
            # Final step is deterministic (no added noise).
            x_cont = mean_x
        if args.clip_k > 0:
            x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k)
        for i, logit in enumerate(logits):
            if t == 0:
                # Final step: commit to the most likely token everywhere.
                probs = F.softmax(logit, dim=-1)
                x_disc[:, :, i] = torch.argmax(probs, dim=-1)
            else:
                # Sample tokens only at still-masked positions.
                # NOTE(review): since every position starts masked and nothing
                # is re-masked, this fills all positions on the first reverse
                # step and the argmax at t == 0 overwrites them — confirm this
                # matches the training-time masking schedule.
                mask = x_disc[:, :, i] == mask_tokens[i]
                if mask.any():
                    probs = F.softmax(logit, dim=-1)
                    sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(
                        args.batch_size, args.seq_len
                    )
                    x_disc[:, :, i][mask] = sampled[mask]
    # Add the stage-1 trend on top of the (normalized-space) diffusion output.
    if trend is not None:
        x_cont = x_cont + trend
    x_cont = x_cont.cpu()
    x_disc = x_disc.cpu()
    if args.clip_k > 0:
        x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k)
    # Map back to real units: inverse quantile transform, or mean/std
    # denormalization followed by undoing any log1p transform.
    if use_quantile:
        q_vals = {c: quantile_values[c] for c in model_cont_cols}
        x_cont = inverse_quantile_transform(x_cont, model_cont_cols, quantile_probs, q_vals)
    else:
        mean_vec = torch.tensor([mean[c] for c in model_cont_cols], dtype=x_cont.dtype)
        std_vec = torch.tensor([std[c] for c in model_cont_cols], dtype=x_cont.dtype)
        x_cont = x_cont * std_vec + mean_vec
        for i, c in enumerate(model_cont_cols):
            if transforms.get(c) == "log1p":
                x_cont[:, :, i] = torch.expm1(x_cont[:, :, i])
    # Optional calibration of the marginal distribution to the real data.
    if cont_post_calibrate and quantile_raw_values and quantile_probs:
        q_raw = {c: quantile_raw_values[c] for c in model_cont_cols}
        x_cont = quantile_calibrate_to_real(x_cont, model_cont_cols, quantile_probs, q_raw)
    # Bound each column to its observed [min, max] range using the configured mode.
    if vmin and vmax:
        for i, c in enumerate(model_cont_cols):
            lo = vmin.get(c, None)
            hi = vmax.get(c, None)
            if lo is None or hi is None:
                continue
            lo = float(lo)
            hi = float(hi)
            if cont_bound_mode == "none":
                continue
            if cont_bound_mode == "sigmoid":
                x_cont[:, :, i] = lo + (hi - lo) * torch.sigmoid(x_cont[:, :, i])
            elif cont_bound_mode == "soft_tanh":
                # Smoothly squash toward the range midpoint; strength <= 0
                # falls back to 1.0 to avoid division by zero.
                mid = 0.5 * (lo + hi)
                half = 0.5 * (hi - lo)
                denom = cont_bound_strength if cont_bound_strength > 0 else 1.0
                x_cont[:, :, i] = mid + half * torch.tanh(x_cont[:, :, i] / denom)
            else:
                x_cont[:, :, i] = torch.clamp(x_cont[:, :, i], lo, hi)
    # Per-column post-hoc scaling from the config; non-numeric values scale by 1.
    if cont_post_scale:
        for i, c in enumerate(model_cont_cols):
            if c in cont_post_scale:
                try:
                    scale = float(cont_post_scale[c])
                except Exception:
                    scale = 1.0
                x_cont[:, :, i] = x_cont[:, :, i] * scale
    # Assemble the full continuous matrix in the original column order:
    # generated columns, then denormalized type1 conditioning columns,
    # then type5 "...Z" columns copied from their base columns.
    full_cont = torch.zeros(args.batch_size, args.seq_len, len(cont_cols), dtype=x_cont.dtype)
    for i, c in enumerate(model_cont_cols):
        full_idx = cont_cols.index(c)
        full_cont[:, :, full_idx] = x_cont[:, :, i]
    if cond_cont is not None and route_type1_cols:
        cond_denorm = denormalize_cont_tensor(
            cond_cont.cpu(),
            route_type1_cols,
            mean,
            std,
            transforms=transforms,
            quantile_probs=quantile_probs,
            quantile_values=quantile_values,
            use_quantile=use_quantile,
        )
        for i, c in enumerate(route_type1_cols):
            full_idx = cont_cols.index(c)
            full_cont[:, :, full_idx] = cond_denorm[:, :, i]
    for c in route_type5_cols:
        if c.endswith("Z"):
            base = c[:-1]
            if base in cont_cols:
                bidx = cont_cols.index(base)
                cidx = cont_cols.index(c)
                full_cont[:, :, cidx] = full_cont[:, :, bidx]
    # Export in the real data's column order (time column only with --include-time).
    header = read_header(data_path)
    out_cols = [c for c in header if c != time_col or args.include_time]
    if args.include_condition and use_condition:
        out_cols = ["__cond_file_id"] + out_cols
    # NOTE(review): assumes args.out resolves to a path with a parent
    # directory component; os.makedirs("") would raise — resolve_path
    # presumably returns an absolute path, verify.
    os.makedirs(os.path.dirname(args.out), exist_ok=True)
    with open(args.out, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=out_cols)
        writer.writeheader()
        row_index = 0
        for b in range(args.batch_size):
            for t in range(args.seq_len):
                row = {}
                if args.include_condition and use_condition:
                    row["__cond_file_id"] = str(int(cond[b].item())) if cond is not None else "-1"
                if args.include_time and time_col in header:
                    row[time_col] = str(row_index)
                # Continuous columns: integers for int-like columns, otherwise
                # formatted with the per-column decimal precision seen in the data.
                for i, c in enumerate(cont_cols):
                    val = float(full_cont[b, t, i])
                    if int_like.get(c, False):
                        row[c] = str(int(round(val)))
                    else:
                        dec = int(max_decimals.get(c, 6))
                        fmt = ("%%.%df" % dec) if dec > 0 else "%.0f"
                        row[c] = fmt % val
                # Discrete columns: map token index back to its string value;
                # out-of-range indices fall back to the column's top token.
                for i, c in enumerate(disc_cols):
                    tok_idx = int(x_disc[b, t, i])
                    tok = inv_vocab[c][tok_idx] if tok_idx < len(inv_vocab[c]) else "<UNK>"
                    if tok == "<UNK>" and c in top_token:
                        tok = top_token[c]
                    row[c] = tok
                writer.writerow(row)
                row_index += 1
# Script entry point: run the full sample-and-export pipeline.
if __name__ == "__main__":
    main()