#!/usr/bin/env python3 """Sample from a trained hybrid diffusion model with routing-aware export fixes.""" from __future__ import annotations import argparse import csv import gzip import json import os from pathlib import Path import torch import torch.nn.functional as F from data_utils import inverse_quantile_transform, load_split, normalize_cont, quantile_calibrate_to_real from export_samples import build_inverse_vocab, load_stats, load_torch_state, read_header from hybrid_diffusion import ( HybridDiffusionModel, TemporalGRUGenerator, TemporalTransformerGenerator, cosine_beta_schedule, ) from platform_utils import resolve_device, resolve_path from submission_type_utils import ( denormalize_cont_tensor, resolve_routing_features, resolve_taxonomy_features, ) def parse_args(): parser = argparse.ArgumentParser(description="Sample and export HAI feature sequences with routing-aware fixes.") base_dir = Path(__file__).resolve().parent repo_dir = base_dir.parent.parent parser.add_argument("--data-path", default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz")) parser.add_argument("--data-glob", default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz")) parser.add_argument("--split-path", default=str(base_dir / "feature_split.json")) parser.add_argument("--stats-path", default=str(base_dir / "results" / "cont_stats.json")) parser.add_argument("--vocab-path", default=str(base_dir / "results" / "disc_vocab.json")) parser.add_argument("--model-path", default=str(base_dir / "results" / "model.pt")) parser.add_argument("--out", default=str(base_dir / "results" / "generated.csv")) parser.add_argument("--timesteps", type=int, default=200) parser.add_argument("--seq-len", type=int, default=64) parser.add_argument("--batch-size", type=int, default=2) parser.add_argument("--device", default="auto", help="cpu, cuda, or auto") parser.add_argument("--include-time", action="store_true", help="Include time column as a simple index") parser.add_argument("--clip-k", type=float, default=5.0, help="Clip continuous values to mean±k*std") parser.add_argument("--use-ema", action="store_true", help="Use EMA weights if available") parser.add_argument("--config", default=None, help="Optional config_used.json to infer conditioning") parser.add_argument("--condition-id", type=int, default=-1, help="Condition file id (0..N-1), -1=random") parser.add_argument("--include-condition", action="store_true", help="Include condition id column in CSV") return parser.parse_args() def main(): args = parse_args() base_dir = Path(__file__).resolve().parent args.data_path = str(resolve_path(base_dir, args.data_path)) args.data_glob = str(resolve_path(base_dir, args.data_glob)) if args.data_glob else "" args.split_path = str(resolve_path(base_dir, args.split_path)) args.stats_path = str(resolve_path(base_dir, args.stats_path)) args.vocab_path = str(resolve_path(base_dir, args.vocab_path)) args.model_path = str(resolve_path(base_dir, args.model_path)) args.out = str(resolve_path(base_dir, args.out)) if not os.path.exists(args.model_path): raise SystemExit(f"missing model file: {args.model_path}") data_path = args.data_path if args.data_glob: base = Path(args.data_glob).parent pat = Path(args.data_glob).name matches = sorted(base.glob(pat)) if matches: data_path = str(matches[0]) split = load_split(args.split_path) time_col = split.get("time_column", "time") cont_cols = [c for c in split["continuous"] if c != time_col] disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col] stats = load_stats(args.stats_path) mean = stats["mean"] std = stats["std"] vmin = stats.get("min", {}) vmax = stats.get("max", {}) int_like = stats.get("int_like", {}) max_decimals = stats.get("max_decimals", {}) transforms = stats.get("transform", {}) quantile_probs = stats.get("quantile_probs") quantile_values = stats.get("quantile_values") quantile_raw_values = stats.get("quantile_raw_values") vocab_json = json.load(open(args.vocab_path, "r", encoding="utf-8")) vocab = vocab_json["vocab"] top_token = vocab_json.get("top_token", {}) inv_vocab = build_inverse_vocab(vocab) vocab_sizes = [len(vocab[c]) for c in disc_cols] device = resolve_device(args.device) cfg = {} use_condition = False cond_vocab_size = 0 if args.config: args.config = str(resolve_path(base_dir, args.config)) if args.config and os.path.exists(args.config): with open(args.config, "r", encoding="utf-8") as f: cfg = json.load(f) use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id" if use_condition: cfg_base = Path(args.config).resolve().parent cfg_glob = cfg.get("data_glob", args.data_glob) cfg_glob = str(resolve_path(cfg_base, cfg_glob)) base = Path(cfg_glob).parent pat = Path(cfg_glob).name cond_vocab_size = len(sorted(base.glob(pat))) if cond_vocab_size <= 0: raise SystemExit(f"use_condition enabled but no files matched data_glob: {cfg_glob}") cont_target = str(cfg.get("cont_target", "eps")) cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0)) use_quantile = bool(cfg.get("use_quantile_transform", False)) cont_bound_mode = str(cfg.get("cont_bound_mode", "clamp")) cont_bound_strength = float(cfg.get("cont_bound_strength", 1.0)) cont_post_scale = cfg.get("cont_post_scale", {}) if isinstance(cfg.get("cont_post_scale", {}), dict) else {} cont_post_calibrate = bool(cfg.get("cont_post_calibrate", False)) route_type1_cols = resolve_routing_features(cfg, cont_cols, "type1_features") route_type5_cols = resolve_routing_features(cfg, cont_cols, "type5_features") type4_cols = resolve_taxonomy_features(cfg, cont_cols, "type4_features") model_cont_cols = [c for c in cont_cols if c not in route_type1_cols and c not in route_type5_cols] use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False)) temporal_use_type1_cond = bool(cfg.get("temporal_use_type1_cond", False)) temporal_focus_type4 = bool(cfg.get("temporal_focus_type4", False)) temporal_exclude_type4 = bool(cfg.get("temporal_exclude_type4", False)) temporal_backbone = str(cfg.get("temporal_backbone", "gru")) temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256)) temporal_num_layers = int(cfg.get("temporal_num_layers", 1)) temporal_dropout = float(cfg.get("temporal_dropout", 0.0)) temporal_pos_dim = int(cfg.get("temporal_pos_dim", 64)) temporal_use_pos_embed = bool(cfg.get("temporal_use_pos_embed", True)) temporal_transformer_num_layers = int(cfg.get("temporal_transformer_num_layers", 2)) temporal_transformer_nhead = int(cfg.get("temporal_transformer_nhead", 4)) temporal_transformer_ff_dim = int(cfg.get("temporal_transformer_ff_dim", 512)) temporal_transformer_dropout = float(cfg.get("temporal_transformer_dropout", 0.1)) backbone_type = str(cfg.get("backbone_type", "gru")) transformer_num_layers = int(cfg.get("transformer_num_layers", 2)) transformer_nhead = int(cfg.get("transformer_nhead", 4)) transformer_ff_dim = int(cfg.get("transformer_ff_dim", 512)) transformer_dropout = float(cfg.get("transformer_dropout", 0.1)) model = HybridDiffusionModel( cont_dim=len(model_cont_cols), disc_vocab_sizes=vocab_sizes, time_dim=int(cfg.get("model_time_dim", 64)), hidden_dim=int(cfg.get("model_hidden_dim", 256)), num_layers=int(cfg.get("model_num_layers", 1)), dropout=float(cfg.get("model_dropout", 0.0)), ff_mult=int(cfg.get("model_ff_mult", 2)), pos_dim=int(cfg.get("model_pos_dim", 64)), use_pos_embed=bool(cfg.get("model_use_pos_embed", True)), backbone_type=backbone_type, transformer_num_layers=transformer_num_layers, transformer_nhead=transformer_nhead, transformer_ff_dim=transformer_ff_dim, transformer_dropout=transformer_dropout, cond_cont_dim=len(route_type1_cols), cond_vocab_size=cond_vocab_size if use_condition else 0, cond_dim=int(cfg.get("cond_dim", 32)), use_tanh_eps=bool(cfg.get("use_tanh_eps", False)), eps_scale=float(cfg.get("eps_scale", 1.0)), ).to(device) if args.use_ema and os.path.exists(args.model_path.replace("model.pt", "model_ema.pt")): ema_path = args.model_path.replace("model.pt", "model_ema.pt") model.load_state_dict(load_torch_state(ema_path, device)) else: model.load_state_dict(load_torch_state(args.model_path, device)) model.eval() temporal_model = None if use_temporal_stage1: temporal_path = Path(args.model_path).with_name("temporal.pt") if not temporal_path.exists(): raise SystemExit(f"missing temporal model file: {temporal_path}") temporal_state = load_torch_state(str(temporal_path), device) temporal_cond_dim = len(route_type1_cols) if (temporal_use_type1_cond and route_type1_cols) else 0 if isinstance(temporal_state, dict): if "in_proj.weight" in temporal_state: try: temporal_cond_dim = max(0, int(temporal_state["in_proj.weight"].shape[1]) - len(model_cont_cols)) except Exception: pass elif "gru.weight_ih_l0" in temporal_state: try: temporal_cond_dim = max(0, int(temporal_state["gru.weight_ih_l0"].shape[1]) - len(model_cont_cols)) except Exception: pass if temporal_backbone == "transformer": temporal_model = TemporalTransformerGenerator( input_dim=len(model_cont_cols), hidden_dim=temporal_hidden_dim, num_layers=temporal_transformer_num_layers, nhead=temporal_transformer_nhead, ff_dim=temporal_transformer_ff_dim, dropout=temporal_transformer_dropout, pos_dim=temporal_pos_dim, use_pos_embed=temporal_use_pos_embed, cond_dim=temporal_cond_dim, ).to(device) else: temporal_model = TemporalGRUGenerator( input_dim=len(model_cont_cols), hidden_dim=temporal_hidden_dim, num_layers=temporal_num_layers, dropout=temporal_dropout, cond_dim=temporal_cond_dim, ).to(device) temporal_model.load_state_dict(temporal_state) temporal_model.eval() betas = cosine_beta_schedule(args.timesteps).to(device) alphas = 1.0 - betas alphas_cumprod = torch.cumprod(alphas, dim=0) x_cont = torch.randn(args.batch_size, args.seq_len, len(model_cont_cols), device=device) x_disc = torch.full((args.batch_size, args.seq_len, len(disc_cols)), 0, device=device, dtype=torch.long) mask_tokens = torch.tensor(vocab_sizes, device=device) for i in range(len(disc_cols)): x_disc[:, :, i] = mask_tokens[i] cond = None if use_condition: if cond_vocab_size <= 0: raise SystemExit("use_condition enabled but no files matched data_glob") if args.condition_id < 0: cond_id = torch.randint(0, cond_vocab_size, (args.batch_size,), device=device) else: cond_id = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long) cond = cond_id cond_cont = None if route_type1_cols: ref_glob = cfg.get("data_glob") or args.data_glob if ref_glob: ref_glob = str(resolve_path(Path(args.config).parent, ref_glob)) if args.config else ref_glob base = Path(ref_glob).parent pat = Path(ref_glob).name refs = sorted(base.glob(pat)) if refs: ref_path = refs[0] ref_rows = [] with gzip.open(ref_path, "rt", newline="") as fh: reader = csv.DictReader(fh) for row in reader: ref_rows.append(row) if len(ref_rows) >= args.seq_len: seq = ref_rows[: args.seq_len] cond_cont = torch.zeros(args.batch_size, args.seq_len, len(route_type1_cols), device=device) for t, row in enumerate(seq): for i, c in enumerate(route_type1_cols): cond_cont[:, t, i] = float(row[c]) cond_cont = normalize_cont( cond_cont, route_type1_cols, mean, std, transforms=transforms, quantile_probs=quantile_probs, quantile_values=quantile_values, use_quantile=use_quantile, ) trend = None if temporal_model is not None: trend = temporal_model.generate(args.batch_size, args.seq_len, device, cond_cont=cond_cont) if temporal_focus_type4 and type4_cols: type4_model_idx = [model_cont_cols.index(c) for c in type4_cols if c in model_cont_cols] if type4_model_idx: trend_mask = torch.zeros(1, 1, len(model_cont_cols), device=device, dtype=trend.dtype) trend_mask[:, :, type4_model_idx] = 1.0 trend = trend * trend_mask elif temporal_exclude_type4 and type4_cols: type4_model_idx = [model_cont_cols.index(c) for c in type4_cols if c in model_cont_cols] if type4_model_idx: trend_mask = torch.ones(1, 1, len(model_cont_cols), device=device, dtype=trend.dtype) trend_mask[:, :, type4_model_idx] = 0.0 trend = trend * trend_mask for t in reversed(range(args.timesteps)): t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long) eps_pred, logits = model(x_cont, x_disc, t_batch, cond, cond_cont=cond_cont) a_t = alphas[t] a_bar_t = alphas_cumprod[t] if cont_target == "x0": x0_pred = eps_pred if cont_clamp_x0 > 0: x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0) eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t) coef1 = 1.0 / torch.sqrt(a_t) coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t) mean_x = coef1 * (x_cont - coef2 * eps_pred) if t > 0: noise = torch.randn_like(x_cont) x_cont = mean_x + torch.sqrt(betas[t]) * noise else: x_cont = mean_x if args.clip_k > 0: x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k) for i, logit in enumerate(logits): if t == 0: probs = F.softmax(logit, dim=-1) x_disc[:, :, i] = torch.argmax(probs, dim=-1) else: mask = x_disc[:, :, i] == mask_tokens[i] if mask.any(): probs = F.softmax(logit, dim=-1) sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view( args.batch_size, args.seq_len ) x_disc[:, :, i][mask] = sampled[mask] if trend is not None: x_cont = x_cont + trend x_cont = x_cont.cpu() x_disc = x_disc.cpu() if args.clip_k > 0: x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k) if use_quantile: q_vals = {c: quantile_values[c] for c in model_cont_cols} x_cont = inverse_quantile_transform(x_cont, model_cont_cols, quantile_probs, q_vals) else: mean_vec = torch.tensor([mean[c] for c in model_cont_cols], dtype=x_cont.dtype) std_vec = torch.tensor([std[c] for c in model_cont_cols], dtype=x_cont.dtype) x_cont = x_cont * std_vec + mean_vec for i, c in enumerate(model_cont_cols): if transforms.get(c) == "log1p": x_cont[:, :, i] = torch.expm1(x_cont[:, :, i]) if cont_post_calibrate and quantile_raw_values and quantile_probs: q_raw = {c: quantile_raw_values[c] for c in model_cont_cols} x_cont = quantile_calibrate_to_real(x_cont, model_cont_cols, quantile_probs, q_raw) if vmin and vmax: for i, c in enumerate(model_cont_cols): lo = vmin.get(c, None) hi = vmax.get(c, None) if lo is None or hi is None: continue lo = float(lo) hi = float(hi) if cont_bound_mode == "none": continue if cont_bound_mode == "sigmoid": x_cont[:, :, i] = lo + (hi - lo) * torch.sigmoid(x_cont[:, :, i]) elif cont_bound_mode == "soft_tanh": mid = 0.5 * (lo + hi) half = 0.5 * (hi - lo) denom = cont_bound_strength if cont_bound_strength > 0 else 1.0 x_cont[:, :, i] = mid + half * torch.tanh(x_cont[:, :, i] / denom) else: x_cont[:, :, i] = torch.clamp(x_cont[:, :, i], lo, hi) if cont_post_scale: for i, c in enumerate(model_cont_cols): if c in cont_post_scale: try: scale = float(cont_post_scale[c]) except Exception: scale = 1.0 x_cont[:, :, i] = x_cont[:, :, i] * scale full_cont = torch.zeros(args.batch_size, args.seq_len, len(cont_cols), dtype=x_cont.dtype) for i, c in enumerate(model_cont_cols): full_idx = cont_cols.index(c) full_cont[:, :, full_idx] = x_cont[:, :, i] if cond_cont is not None and route_type1_cols: cond_denorm = denormalize_cont_tensor( cond_cont.cpu(), route_type1_cols, mean, std, transforms=transforms, quantile_probs=quantile_probs, quantile_values=quantile_values, use_quantile=use_quantile, ) for i, c in enumerate(route_type1_cols): full_idx = cont_cols.index(c) full_cont[:, :, full_idx] = cond_denorm[:, :, i] for c in route_type5_cols: if c.endswith("Z"): base = c[:-1] if base in cont_cols: bidx = cont_cols.index(base) cidx = cont_cols.index(c) full_cont[:, :, cidx] = full_cont[:, :, bidx] header = read_header(data_path) out_cols = [c for c in header if c != time_col or args.include_time] if args.include_condition and use_condition: out_cols = ["__cond_file_id"] + out_cols os.makedirs(os.path.dirname(args.out), exist_ok=True) with open(args.out, "w", newline="", encoding="utf-8") as f: writer = csv.DictWriter(f, fieldnames=out_cols) writer.writeheader() row_index = 0 for b in range(args.batch_size): for t in range(args.seq_len): row = {} if args.include_condition and use_condition: row["__cond_file_id"] = str(int(cond[b].item())) if cond is not None else "-1" if args.include_time and time_col in header: row[time_col] = str(row_index) for i, c in enumerate(cont_cols): val = float(full_cont[b, t, i]) if int_like.get(c, False): row[c] = str(int(round(val))) else: dec = int(max_decimals.get(c, 6)) fmt = ("%%.%df" % dec) if dec > 0 else "%.0f" row[c] = fmt % val for i, c in enumerate(disc_cols): tok_idx = int(x_disc[b, t, i]) tok = inv_vocab[c][tok_idx] if tok_idx < len(inv_vocab[c]) else "" if tok == "" and c in top_token: tok = top_token[c] row[c] = tok writer.writerow(row) row_index += 1 if __name__ == "__main__": main()