#!/usr/bin/env python3
"""Train hybrid diffusion on HAI (configurable runnable example).

Usage:
    python train.py [--config CONFIG.json] [--device cpu|cuda|auto]

Values from ``--config`` override ``DEFAULTS``; relative paths in a config file
are resolved against that file's directory. Outputs (training log, the config
actually used, checkpoints, final weights) are written to ``out_dir``.
"""

import argparse
import json
import os
import random
from pathlib import Path
from typing import Dict

import torch
import torch.nn.functional as F

from data_utils import load_split, windowed_batches
from hybrid_diffusion import (
    HybridDiffusionModel,
    cosine_beta_schedule,
    q_sample_continuous,
    q_sample_discrete,
)
from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path

BASE_DIR = Path(__file__).resolve().parent
REPO_DIR = BASE_DIR.parent.parent

DEFAULTS = {
    "data_path": REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz",
    "data_glob": REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz",
    "split_path": BASE_DIR / "feature_split.json",
    "stats_path": BASE_DIR / "results" / "cont_stats.json",
    "vocab_path": BASE_DIR / "results" / "disc_vocab.json",
    "out_dir": BASE_DIR / "results",
    "device": "auto",
    "timesteps": 1000,
    "batch_size": 8,
    "seq_len": 64,
    "epochs": 1,
    "max_batches": 50,
    "lambda": 0.5,
    "lr": 1e-3,
    "seed": 1337,
    "log_every": 10,
    "ckpt_every": 50,
    "ema_decay": 0.999,
    "use_ema": True,
    "clip_k": 5.0,
    "grad_clip": 1.0,
    "use_condition": True,
    "condition_type": "file_id",
    "cond_dim": 32,
    "use_tanh_eps": True,
    "eps_scale": 1.0,
}


def load_json(path: str) -> Dict:
    """Read and return the JSON document stored at *path* (UTF-8)."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def set_seed(seed: int):
    """Seed Python's and PyTorch's RNGs and force deterministic cuDNN."""
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Device selection uses resolve_device from platform_utils.


def parse_args():
    """Parse the command line (``--config`` and ``--device``)."""
    parser = argparse.ArgumentParser(description="Train hybrid diffusion on HAI.")
    parser.add_argument("--config", default=None, help="Path to JSON config.")
    parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
    return parser.parse_args()


def _is_glob_pattern(path_str: str) -> bool:
    """Return True if *path_str* contains glob metacharacters (``*``, ``?``, ``[``)."""
    return any(ch in path_str for ch in "*?[")


def resolve_config_paths(config, base_dir: Path):
    """Resolve the path-valued entries of *config* against *base_dir*, in place.

    Each handled key (str or Path value) is replaced by a string. Absolute
    paths are kept as-is; relative ones go through ``resolve_path``. Glob
    patterns are only joined onto *base_dir*, never resolved, because a
    pattern cannot be ``Path.resolve()``'d on Windows. Missing, ``None`` or
    empty entries are left untouched so that e.g. ``"data_glob": ""`` keeps
    meaning "no glob" instead of turning into *base_dir*.

    Returns:
        The same *config* dict, for convenience.
    """
    keys = ["data_path", "data_glob", "split_path", "stats_path", "vocab_path", "out_dir"]
    for key in keys:
        value = config.get(key)
        if value is None or value == "":
            continue
        # Accept both str and Path values uniformly.
        path = Path(value)
        if _is_glob_pattern(str(path)):
            # Joining onto an absolute pattern yields the pattern unchanged.
            config[key] = str(base_dir / path)
        elif path.is_absolute():
            config[key] = str(path)
        else:
            config[key] = str(resolve_path(base_dir, path))
    return config


class EMA:
    """Exponential moving average of a model's trainable parameters.

    Only parameters with ``requires_grad`` at construction time are tracked;
    the shadow copies are detached clones of those parameters.
    """

    def __init__(self, model, decay: float):
        self.decay = decay
        self.shadow = {
            name: param.detach().clone()
            for name, param in model.named_parameters()
            if param.requires_grad
        }

    def update(self, model):
        """Blend the current parameters into the shadows.

        shadow <- decay * shadow + (1 - decay) * param
        """
        with torch.no_grad():
            for name, param in model.named_parameters():
                if not param.requires_grad:
                    continue
                # In-place update: no new tensor allocated per parameter per step.
                self.shadow[name].mul_(self.decay).add_(param.detach(), alpha=1.0 - self.decay)

    def state_dict(self):
        """Return the name -> shadow-tensor mapping (live references, not copies)."""
        return self.shadow


def _resolve_data_paths(config):
    """Return the training files: sorted ``data_glob`` matches, else ``[data_path]``.

    The glob's last component is matched inside its parent directory only.
    """
    data_glob = config.get("data_glob")
    data_paths = []
    if data_glob:
        pattern = Path(data_glob)
        data_paths = [safe_path(p) for p in sorted(pattern.parent.glob(pattern.name))]
    if not data_paths:
        data_paths = [safe_path(config["data_path"])]
    return data_paths


def _masked_discrete_loss(logits, x_disc, mask):
    """Mean cross-entropy over the discrete columns that have masked positions.

    For each column ``i``, the loss is computed only at positions where
    ``mask[:, :, i]`` is true; columns with no such positions are skipped.
    Returns the float ``0.0`` when no column contributes.
    """
    total = 0.0
    count = 0
    for i, logit in enumerate(logits):
        col_mask = mask[:, :, i]
        if col_mask.any():
            total = total + F.cross_entropy(logit[col_mask], x_disc[:, :, i][col_mask])
            count += 1
    return total / count if count > 0 else total


def main():
    """Build the config, train the model, and write logs/checkpoints to out_dir."""
    args = parse_args()
    config = dict(DEFAULTS)
    if args.config:
        cfg_path = Path(args.config).resolve()
        config.update(load_json(str(cfg_path)))
        # Relative paths in a config file are relative to that file.
        config = resolve_config_paths(config, cfg_path.parent)
    else:
        config = resolve_config_paths(config, BASE_DIR)

    # A device given explicitly on the command line takes precedence.
    if args.device != "auto":
        config["device"] = args.device

    set_seed(int(config["seed"]))

    split = load_split(config["split_path"])
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    # Columns whose names start with "attack" are not modeled.
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]

    stats = load_json(config["stats_path"])
    mean = stats["mean"]
    std = stats["std"]
    vocab = load_json(config["vocab_path"])["vocab"]
    vocab_sizes = [len(vocab[c]) for c in disc_cols]

    data_paths = _resolve_data_paths(config)

    use_condition = bool(config.get("use_condition")) and config.get("condition_type") == "file_id"
    # file_id conditioning: the condition vocabulary has one entry per data file.
    cond_vocab_size = len(data_paths) if use_condition else 0

    device = resolve_device(str(config["device"]))
    print("device", device)

    model = HybridDiffusionModel(
        cont_dim=len(cont_cols),
        disc_vocab_sizes=vocab_sizes,
        cond_vocab_size=cond_vocab_size,
        cond_dim=int(config.get("cond_dim", 32)),
        use_tanh_eps=bool(config.get("use_tanh_eps", False)),
        eps_scale=float(config.get("eps_scale", 1.0)),
    ).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))
    ema = EMA(model, float(config["ema_decay"])) if config.get("use_ema") else None

    timesteps = int(config["timesteps"])
    betas = cosine_beta_schedule(timesteps).to(device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    # Loop-invariant settings, parsed once instead of on every step.
    lam = float(config["lambda"])
    grad_clip = float(config.get("grad_clip", 0.0))
    # A non-positive interval disables logging / checkpointing (avoids modulo by zero).
    log_every = int(config["log_every"])
    ckpt_every = int(config["ckpt_every"])
    # Each discrete column's mask token id passed to q_sample_discrete is its vocab size.
    mask_tokens = torch.tensor(vocab_sizes, device=device)

    os.makedirs(config["out_dir"], exist_ok=True)
    out_dir = safe_path(config["out_dir"])
    log_path = os.path.join(out_dir, "train_log.csv")
    with open(log_path, "w", encoding="utf-8") as f:
        f.write("epoch,step,loss,loss_cont,loss_disc\n")
    with open(os.path.join(out_dir, "config_used.json"), "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    total_step = 0
    for epoch in range(int(config["epochs"])):
        batches = windowed_batches(
            data_paths,
            cont_cols,
            disc_cols,
            vocab,
            mean,
            std,
            batch_size=int(config["batch_size"]),
            seq_len=int(config["seq_len"]),
            max_batches=int(config["max_batches"]),
            return_file_id=use_condition,
        )
        for step, batch in enumerate(batches):
            if use_condition:
                x_cont, x_disc, cond = batch
                cond = cond.to(device)
            else:
                x_cont, x_disc = batch
                cond = None
            x_cont = x_cont.to(device)
            x_disc = x_disc.to(device)

            # One diffusion timestep per sample in the batch.
            bsz = x_cont.size(0)
            t = torch.randint(0, timesteps, (bsz,), device=device)
            x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
            x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, timesteps)

            eps_pred, logits = model(x_cont_t, x_disc_t, t, cond)

            loss_cont = F.mse_loss(eps_pred, noise)
            loss_disc = _masked_discrete_loss(logits, x_disc, mask)
            # lambda weights the continuous term; (1 - lambda) the discrete term.
            loss = lam * loss_cont + (1 - lam) * loss_disc

            opt.zero_grad()
            loss.backward()
            if grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            opt.step()
            if ema is not None:
                ema.update(model)

            # `step` restarts at 0 each epoch; logging cadence is per epoch.
            if log_every > 0 and step % log_every == 0:
                print("epoch", epoch, "step", step, "loss", float(loss))
                with open(log_path, "a", encoding="utf-8") as f:
                    f.write(
                        "%d,%d,%.6f,%.6f,%.6f\n"
                        % (epoch, step, float(loss), float(loss_cont), float(loss_disc))
                    )

            # Checkpoint cadence counts steps across all epochs.
            total_step += 1
            if ckpt_every > 0 and total_step % ckpt_every == 0:
                ckpt = {
                    "model": model.state_dict(),
                    "optim": opt.state_dict(),
                    "config": config,
                    "step": total_step,
                }
                if ema is not None:
                    ckpt["ema"] = ema.state_dict()
                torch.save(ckpt, os.path.join(out_dir, "model_ckpt.pt"))

    torch.save(model.state_dict(), os.path.join(out_dir, "model.pt"))
    if ema is not None:
        torch.save(ema.state_dict(), os.path.join(out_dir, "model_ema.pt"))


if __name__ == "__main__":
    main()