#!/usr/bin/env python3
"""Train hybrid diffusion on HAI (configurable runnable example).

Loads a feature split, continuous-feature statistics and a discrete vocabulary,
then trains ``HybridDiffusionModel`` on a weighted sum of a continuous
noise-prediction MSE loss and a discrete masked-token cross-entropy loss.

Every setting in ``DEFAULTS`` can be overridden by a JSON file passed via
``--config``. Relative paths in that file are resolved against the config
file's own directory.
"""

import argparse
import json
import os
import random
from pathlib import Path
from typing import Dict

import torch
import torch.nn.functional as F

from data_utils import load_split, windowed_batches
from hybrid_diffusion import (
    HybridDiffusionModel,
    cosine_beta_schedule,
    q_sample_continuous,
    q_sample_discrete,
)

BASE_DIR = Path(__file__).resolve().parent
REPO_DIR = BASE_DIR.parent.parent

DEFAULTS = {
    "data_path": str(REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz"),
    "split_path": str(BASE_DIR / "feature_split.json"),
    "stats_path": str(BASE_DIR / "results" / "cont_stats.json"),
    "vocab_path": str(BASE_DIR / "results" / "disc_vocab.json"),
    "out_dir": str(BASE_DIR / "results"),
    "device": "auto",
    "timesteps": 1000,
    "batch_size": 8,
    "seq_len": 64,
    "epochs": 1,
    "max_batches": 50,
    "lambda": 0.5,
    "lr": 1e-3,
    "seed": 1337,
    "log_every": 10,
    "ckpt_every": 50,
}


def load_json(path: str) -> Dict:
    """Read and parse the JSON file at *path*.

    The file is decoded as UTF-8, the encoding required for JSON text by
    RFC 8259. UTF-8 is a strict superset of ASCII, so files that loaded before
    still load. Non-ASCII content such as column names or paths no longer
    raises ``UnicodeDecodeError``.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def set_seed(seed: int):
    """Seed Python's ``random`` and torch (CPU, plus all CUDA devices if present).

    When CUDA is available, cuDNN is also switched to deterministic mode and
    its autotuner is disabled, trading speed for run-to-run reproducibility.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def resolve_device(mode: str) -> str:
    """Map a device setting to ``"cpu"`` or ``"cuda"``.

    ``"cpu"`` and ``"cuda"`` (case-insensitive) are honored explicitly.
    Requesting ``"cuda"`` without CUDA support exits with an error. Any other
    value, including ``"auto"``, picks CUDA when available and CPU otherwise.
    """
    mode = mode.lower()
    if mode == "cpu":
        return "cpu"
    if mode == "cuda":
        if not torch.cuda.is_available():
            raise SystemExit("device set to cuda but CUDA is not available")
        return "cuda"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


def parse_args():
    """Parse the command line. The only option is ``--config`` (path to a JSON file)."""
    parser = argparse.ArgumentParser(description="Train hybrid diffusion on HAI.")
    parser.add_argument("--config", default=None, help="Path to JSON config.")
    return parser.parse_args()


def resolve_config_paths(config, base_dir: Path):
    """Make the path-valued entries of *config* absolute, in place.

    Relative values of the path keys are interpreted relative to *base_dir*.
    Absolute values and keys missing from *config* are left alone. Returns
    the same (mutated) dict for convenience.
    """
    keys = ["data_path", "split_path", "stats_path", "vocab_path", "out_dir"]
    for key in keys:
        if key in config:
            path = Path(str(config[key]))
            if not path.is_absolute():
                config[key] = str((base_dir / path).resolve())
    return config


def _discrete_loss(logits, x_disc, mask):
    """Sum the cross-entropy over masked positions of each discrete column.

    ``logits[i]`` holds the predictions for discrete column ``i``, and
    ``mask[:, :, i]`` selects the positions of that column that were masked
    (assumed to be a (batch, seq, n_disc) boolean tensor — TODO confirm
    against ``q_sample_discrete``). Columns with no masked position contribute
    nothing. When no column is masked, the result is the Python float ``0.0``
    rather than a tensor.
    """
    total = 0.0
    for i, logit in enumerate(logits):
        col_mask = mask[:, :, i]
        if col_mask.any():
            total = total + F.cross_entropy(logit[col_mask], x_disc[:, :, i][col_mask])
    return total


def main():
    """Run training according to ``DEFAULTS`` merged with an optional JSON config.

    Writes, under ``out_dir``:

    - ``train_log.csv``: one row every ``log_every`` steps of each epoch.
    - ``model_ckpt.pt``: model state, optimizer state, config and global step,
      saved every ``ckpt_every`` global steps.
    - ``model.pt``: the final model ``state_dict``.
    """
    args = parse_args()
    config = dict(DEFAULTS)
    if args.config:
        cfg_path = Path(args.config).resolve()
        config.update(load_json(str(cfg_path)))
        config = resolve_config_paths(config, cfg_path.parent)
    else:
        config = resolve_config_paths(config, BASE_DIR)

    set_seed(int(config["seed"]))

    # Coerce numeric settings once, up front. Malformed values then fail
    # before any data is loaded, and the training loop no longer re-casts them
    # on every step.
    timesteps = int(config["timesteps"])
    epochs = int(config["epochs"])
    batch_size = int(config["batch_size"])
    seq_len = int(config["seq_len"])
    max_batches = int(config["max_batches"])
    log_every = int(config["log_every"])
    ckpt_every = int(config["ckpt_every"])
    lam = float(config["lambda"])
    lr = float(config["lr"])
    if log_every <= 0 or ckpt_every <= 0:
        # These are modulo divisors below. Zero would otherwise surface as a
        # ZeroDivisionError only after the first optimizer step.
        raise SystemExit("log_every and ckpt_every must be positive integers")

    split = load_split(config["split_path"])
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    # Discrete columns whose names start with "attack" are excluded from training.
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]

    stats = load_json(config["stats_path"])
    mean = stats["mean"]
    std = stats["std"]

    vocab = load_json(config["vocab_path"])["vocab"]
    vocab_sizes = [len(vocab[c]) for c in disc_cols]

    device = resolve_device(str(config["device"]))
    print("device", device)
    model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    betas = cosine_beta_schedule(timesteps).to(device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    # Loop-invariant: previously rebuilt on every step. One entry per discrete
    # column, passed to q_sample_discrete as that column's mask token
    # (presumably id == vocab size — confirm in hybrid_diffusion).
    mask_tokens = torch.tensor(vocab_sizes, device=device)

    out_dir = config["out_dir"]
    os.makedirs(out_dir, exist_ok=True)
    log_path = os.path.join(out_dir, "train_log.csv")
    with open(log_path, "w", encoding="ascii") as f:
        f.write("epoch,step,loss,loss_cont,loss_disc\n")

    total_step = 0
    for epoch in range(epochs):
        batches = windowed_batches(
            config["data_path"],
            cont_cols,
            disc_cols,
            vocab,
            mean,
            std,
            batch_size=batch_size,
            seq_len=seq_len,
            max_batches=max_batches,
        )
        for step, (x_cont, x_disc) in enumerate(batches):
            x_cont = x_cont.to(device)
            x_disc = x_disc.to(device)
            bsz = x_cont.size(0)

            # Forward-diffuse both modalities to a random timestep per sample.
            t = torch.randint(0, timesteps, (bsz,), device=device)
            x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
            x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, timesteps)

            eps_pred, logits = model(x_cont_t, x_disc_t, t)
            loss_cont = F.mse_loss(eps_pred, noise)
            loss_disc = _discrete_loss(logits, x_disc, mask)
            loss = lam * loss_cont + (1 - lam) * loss_disc

            opt.zero_grad()
            loss.backward()
            opt.step()

            # Logging cadence is per-epoch step; checkpoint cadence is global step.
            if step % log_every == 0:
                print("epoch", epoch, "step", step, "loss", float(loss))
                with open(log_path, "a", encoding="ascii") as f:
                    f.write(
                        "%d,%d,%.6f,%.6f,%.6f\n"
                        % (epoch, step, float(loss), float(loss_cont), float(loss_disc))
                    )

            total_step += 1
            if total_step % ckpt_every == 0:
                ckpt = {
                    "model": model.state_dict(),
                    "optim": opt.state_dict(),
                    "config": config,
                    "step": total_step,
                }
                torch.save(ckpt, os.path.join(out_dir, "model_ckpt.pt"))

    torch.save(model.state_dict(), os.path.join(out_dir, "model.pt"))


if __name__ == "__main__":
    main()