#!/usr/bin/env python3
"""Train hybrid diffusion on HAI (configurable runnable example)."""

import argparse
import json
import os
import random
from pathlib import Path
from typing import Dict

import torch
import torch.nn.functional as F

from data_utils import load_split, windowed_batches
from hybrid_diffusion import (
    HybridDiffusionModel,
    cosine_beta_schedule,
    q_sample_continuous,
    q_sample_discrete,
)
from platform_utils import resolve_device

BASE_DIR = Path(__file__).resolve().parent
REPO_DIR = BASE_DIR.parent.parent

DEFAULTS = {
    "data_path": REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz",
    "split_path": BASE_DIR / "feature_split.json",
    "stats_path": BASE_DIR / "results" / "cont_stats.json",
    "vocab_path": BASE_DIR / "results" / "disc_vocab.json",
    "out_dir": BASE_DIR / "results",
    "device": "auto",
    "timesteps": 1000,
    "batch_size": 8,
    "seq_len": 64,
    "epochs": 1,
    "max_batches": 50,
    "lambda": 0.5,
    "lr": 1e-3,
    "seed": 1337,
    "log_every": 10,
    "ckpt_every": 50,
}


def load_json(path: str) -> Dict:
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def set_seed(seed: int):
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Device resolution is delegated to resolve_device from platform_utils.


def parse_args():
    parser = argparse.ArgumentParser(description="Train hybrid diffusion on HAI.")
    parser.add_argument("--config", default=None, help="Path to JSON config.")
    parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
    return parser.parse_args()


def resolve_config_paths(config, base_dir: Path):
    """Resolve any relative path entries in the config against base_dir."""
    keys = ["data_path", "split_path", "stats_path", "vocab_path", "out_dir"]
    for key in keys:
        if key in config:
            # Coerce string values to Path objects before resolving.
            if isinstance(config[key], str):
                path = Path(config[key])
            else:
                path = config[key]
            if not path.is_absolute():
                config[key] = str((base_dir / path).resolve())
            else:
                config[key] = str(path)
    return config


def main():
    args = parse_args()
    config = dict(DEFAULTS)
    if args.config:
        cfg_path = Path(args.config).resolve()
        config.update(load_json(str(cfg_path)))
        config = resolve_config_paths(config, cfg_path.parent)
    else:
        config = resolve_config_paths(config, BASE_DIR)

    # A --device flag passed on the command line takes precedence over the config.
    if args.device != "auto":
        config["device"] = args.device

    set_seed(int(config["seed"]))

    split = load_split(config["split_path"])
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]

    stats = load_json(config["stats_path"])
    mean = stats["mean"]
    std = stats["std"]
    vocab = load_json(config["vocab_path"])["vocab"]
    vocab_sizes = [len(vocab[c]) for c in disc_cols]

    device = resolve_device(str(config["device"]))
    print("device", device)

    model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))

    # Standard DDPM quantities derived from the cosine noise schedule.
    betas = cosine_beta_schedule(int(config["timesteps"])).to(device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    os.makedirs(config["out_dir"], exist_ok=True)
    log_path = os.path.join(config["out_dir"], "train_log.csv")
    with open(log_path, "w", encoding="utf-8") as f:
        f.write("epoch,step,loss,loss_cont,loss_disc\n")

    total_step = 0
    for epoch in range(int(config["epochs"])):
        for step, (x_cont, x_disc) in enumerate(
            windowed_batches(
                config["data_path"],
                cont_cols,
                disc_cols,
                vocab,
                mean,
                std,
                batch_size=int(config["batch_size"]),
                seq_len=int(config["seq_len"]),
                max_batches=int(config["max_batches"]),
            )
        ):
            x_cont = x_cont.to(device)
            x_disc = x_disc.to(device)
            bsz = x_cont.size(0)
            # Sample one diffusion timestep per sequence in the batch.
            t = torch.randint(0, int(config["timesteps"]), (bsz,), device=device)

            # Forward (noising) process: Gaussian noise for continuous
            # channels, mask-token corruption for discrete channels.
            x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
            mask_tokens = torch.tensor(vocab_sizes, device=device)
            x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, int(config["timesteps"]))

            eps_pred, logits = model(x_cont_t, x_disc_t, t)
            loss_cont = F.mse_loss(eps_pred, noise)

            # Cross-entropy on masked positions only, summed over discrete channels.
            loss_disc = 0.0
            for i, logit in enumerate(logits):
                if mask[:, :, i].any():
                    loss_disc = loss_disc + F.cross_entropy(
                        logit[mask[:, :, i]], x_disc[:, :, i][mask[:, :, i]]
                    )

            lam = float(config["lambda"])
            loss = lam * loss_cont + (1 - lam) * loss_disc

            opt.zero_grad()
            loss.backward()
            opt.step()

            if step % int(config["log_every"]) == 0:
                print("epoch", epoch, "step", step, "loss", float(loss))
                with open(log_path, "a", encoding="utf-8") as f:
                    f.write(
                        "%d,%d,%.6f,%.6f,%.6f\n"
                        % (epoch, step, float(loss), float(loss_cont), float(loss_disc))
                    )

            total_step += 1
            if total_step % int(config["ckpt_every"]) == 0:
                ckpt = {
                    "model": model.state_dict(),
                    "optim": opt.state_dict(),
                    "config": config,
                    "step": total_step,
                }
                torch.save(ckpt, os.path.join(config["out_dir"], "model_ckpt.pt"))

    torch.save(model.state_dict(), os.path.join(config["out_dir"], "model.pt"))


if __name__ == "__main__":
    main()
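
# ---------------------------------------------------------------------------
# Usage sketch. The filename "train_hybrid.py" below is hypothetical; invoke
# the script under whatever name it has in this repo:
#
#   python train_hybrid.py                        # defaults, auto device
#   python train_hybrid.py --device cpu
#   python train_hybrid.py --config run.json --device cuda
#
# A config JSON may override any key in DEFAULTS, for example:
#
#   {"epochs": 2, "batch_size": 16, "lambda": 0.7, "out_dir": "results_run2"}
#
# Relative paths in a config are resolved against the config file's own
# directory (see resolve_config_paths). Top-level shapes of the auxiliary
# JSON files, as inferred from how they are read in main():
#
#   feature_split.json : {"time_column": ..., "continuous": [...], "discrete": [...]}
#   cont_stats.json    : {"mean": ..., "std": ...}
#   disc_vocab.json    : {"vocab": {column: [token, ...]}}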