#!/usr/bin/env python3
"""Train hybrid diffusion on HAI (configurable runnable example).

Two-stage training:

* Stage 1 (optional): a temporal GRU (``TemporalGRUGenerator``) is trained
  with teacher forcing to predict the next continuous frame; its output is
  later used as a per-window "trend".
* Stage 2: a ``HybridDiffusionModel`` is trained to denoise the continuous
  residual (data minus trend, Gaussian noising) and the discrete channels
  (mask-token noising) jointly.

All knobs live in ``DEFAULTS`` and can be overridden by a JSON config file
passed via ``--config``; path-valued entries are resolved relative to the
config file's directory.
"""

import argparse
import json
import os
import random
from pathlib import Path
from typing import Dict

import torch
import torch.nn.functional as F

from data_utils import load_split, windowed_batches
from hybrid_diffusion import (
    HybridDiffusionModel,
    TemporalGRUGenerator,
    cosine_beta_schedule,
    q_sample_continuous,
    q_sample_discrete,
)
from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path

BASE_DIR = Path(__file__).resolve().parent
REPO_DIR = BASE_DIR.parent.parent

# Every tunable of the script.  A JSON config (--config) overrides these keys;
# path values may be relative and are resolved by resolve_config_paths().
DEFAULTS = {
    "data_path": REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz",
    "data_glob": REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz",
    "split_path": BASE_DIR / "feature_split.json",
    "stats_path": BASE_DIR / "results" / "cont_stats.json",
    "vocab_path": BASE_DIR / "results" / "disc_vocab.json",
    "out_dir": BASE_DIR / "results",
    "device": "auto",
    "timesteps": 1000,
    "batch_size": 8,
    "seq_len": 64,
    "epochs": 1,
    "max_batches": 50,
    "lambda": 0.5,  # mixing weight: lam * cont_loss + (1 - lam) * disc_loss
    "lr": 1e-3,
    "seed": 1337,
    "log_every": 10,
    "ckpt_every": 50,
    "ema_decay": 0.999,
    "use_ema": True,
    "clip_k": 5.0,
    "grad_clip": 1.0,
    "use_condition": True,
    "condition_type": "file_id",
    "cond_dim": 32,
    "use_tanh_eps": False,
    "eps_scale": 1.0,
    "model_time_dim": 128,
    "model_hidden_dim": 512,
    "model_num_layers": 2,
    "model_dropout": 0.1,
    "model_ff_mult": 2,
    "model_pos_dim": 64,
    "model_use_pos_embed": True,
    "disc_mask_scale": 0.9,
    "shuffle_buffer": 256,
    "cont_loss_weighting": "none",  # none | inv_std
    "cont_loss_eps": 1e-6,
    "cont_target": "eps",  # eps | x0
    "cont_clamp_x0": 0.0,
    "use_temporal_stage1": True,
    "temporal_hidden_dim": 256,
    "temporal_num_layers": 1,
    "temporal_dropout": 0.0,
    "temporal_epochs": 2,
    "temporal_lr": 1e-3,
    "quantile_loss_weight": 0.0,
    "quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
    "snr_weighted_loss": True,
    "snr_gamma": 1.0,
    "residual_stat_weight": 0.0,
}


def load_json(path: str) -> Dict:
    """Read a UTF-8 JSON file and return the decoded object."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def set_seed(seed: int) -> None:
    """Seed the python and torch RNGs and force deterministic cuDNN kernels."""
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Device selection is delegated to platform_utils.resolve_device.


def parse_args():
    """CLI: --config (JSON config path) and --device (cpu / cuda / auto)."""
    parser = argparse.ArgumentParser(description="Train hybrid diffusion on HAI.")
    parser.add_argument("--config", default=None, help="Path to JSON config.")
    parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
    return parser.parse_args()


def resolve_config_paths(config, base_dir: Path):
    """Normalize all path-valued config entries to absolute path strings.

    Relative paths are resolved against *base_dir* (the config file's
    directory, or this script's directory when no config was given).  Glob
    patterns are joined textually because Path.resolve() would mangle the
    wildcard on Windows.  Mutates and returns *config*.
    """
    keys = ["data_path", "data_glob", "split_path", "stats_path", "vocab_path", "out_dir"]
    for key in keys:
        if key in config:
            # String values may contain glob wildcards; Path objects come
            # straight from DEFAULTS and never do.
            if isinstance(config[key], str):
                path_str = config[key]
                # glob pattern cannot be Path.resolve()'d on Windows
                if "*" in path_str or "?" in path_str or "[" in path_str:
                    config[key] = str((base_dir / Path(path_str)))
                    continue
                path = Path(path_str)
            else:
                path = config[key]
            if not path.is_absolute():
                config[key] = str(resolve_path(base_dir, path))
            else:
                config[key] = str(path)
    return config


class EMA:
    """Exponential moving average of a model's trainable parameters.

    ``shadow`` maps parameter name -> averaged tensor; ``state_dict()``
    exposes it for checkpointing.
    """

    def __init__(self, model, decay: float):
        self.decay = decay
        self.shadow = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.detach().clone()

    def update(self, model):
        """shadow <- decay * shadow + (1 - decay) * param for each trainable param."""
        with torch.no_grad():
            for name, param in model.named_parameters():
                if not param.requires_grad:
                    continue
                old = self.shadow[name]
                self.shadow[name] = old * self.decay + param.detach() * (1.0 - self.decay)

    def state_dict(self):
        return self.shadow


def main():
    args = parse_args()
    config = dict(DEFAULTS)
    if args.config:
        cfg_path = Path(args.config).resolve()
        print("using_config", str(cfg_path))
        config.update(load_json(str(cfg_path)))
        config = resolve_config_paths(config, cfg_path.parent)
    else:
        config = resolve_config_paths(config, BASE_DIR)
    # A CLI --device other than "auto" overrides whatever the config says.
    if args.device != "auto":
        config["device"] = args.device
    set_seed(int(config["seed"]))

    # Feature split: which columns are continuous vs discrete.  The time
    # column and attack labels are never modeled.
    split = load_split(config["split_path"])
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]

    # Normalization stats for continuous channels and vocab for discrete ones.
    stats = load_json(config["stats_path"])
    mean = stats["mean"]
    std = stats["std"]
    transforms = stats.get("transform", {})
    raw_std = stats.get("raw_std", std)
    vocab = load_json(config["vocab_path"])["vocab"]
    vocab_sizes = [len(vocab[c]) for c in disc_cols]

    # Expand the data glob; fall back to the single data_path if it matches
    # nothing (or is empty).
    data_paths = None
    if "data_glob" in config and config["data_glob"]:
        data_paths = sorted(Path(config["data_glob"]).parent.glob(Path(config["data_glob"]).name))
        if data_paths:
            data_paths = [safe_path(p) for p in data_paths]
    if not data_paths:
        data_paths = [safe_path(config["data_path"])]

    # Optional per-file conditioning; only the "file_id" type is supported.
    use_condition = bool(config.get("use_condition")) and config.get("condition_type") == "file_id"
    cond_vocab_size = len(data_paths) if use_condition else 0

    device = resolve_device(str(config["device"]))
    print("device", device)

    model = HybridDiffusionModel(
        cont_dim=len(cont_cols),
        disc_vocab_sizes=vocab_sizes,
        time_dim=int(config.get("model_time_dim", 64)),
        hidden_dim=int(config.get("model_hidden_dim", 256)),
        num_layers=int(config.get("model_num_layers", 1)),
        dropout=float(config.get("model_dropout", 0.0)),
        ff_mult=int(config.get("model_ff_mult", 2)),
        pos_dim=int(config.get("model_pos_dim", 64)),
        use_pos_embed=bool(config.get("model_use_pos_embed", True)),
        cond_vocab_size=cond_vocab_size,
        cond_dim=int(config.get("cond_dim", 32)),
        use_tanh_eps=bool(config.get("use_tanh_eps", False)),
        eps_scale=float(config.get("eps_scale", 1.0)),
    ).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))

    temporal_model = None
    opt_temporal = None
    if bool(config.get("use_temporal_stage1", False)):
        temporal_model = TemporalGRUGenerator(
            input_dim=len(cont_cols),
            hidden_dim=int(config.get("temporal_hidden_dim", 256)),
            num_layers=int(config.get("temporal_num_layers", 1)),
            dropout=float(config.get("temporal_dropout", 0.0)),
        ).to(device)
        opt_temporal = torch.optim.Adam(
            temporal_model.parameters(),
            lr=float(config.get("temporal_lr", config["lr"])),
        )

    ema = EMA(model, float(config["ema_decay"])) if config.get("use_ema") else None

    # Diffusion noise schedule (cosine betas, cumulative alpha products).
    betas = cosine_beta_schedule(int(config["timesteps"])).to(device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    os.makedirs(config["out_dir"], exist_ok=True)
    out_dir = safe_path(config["out_dir"])
    log_path = os.path.join(out_dir, "train_log.csv")
    with open(log_path, "w", encoding="utf-8") as f:
        f.write("epoch,step,loss,loss_cont,loss_disc\n")
    # Persist the fully-resolved config next to the artifacts for reproducibility.
    with open(os.path.join(out_dir, "config_used.json"), "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    # ---- Stage 1: teacher-forced next-step prediction for the trend model.
    if temporal_model is not None and opt_temporal is not None:
        for epoch in range(int(config.get("temporal_epochs", 1))):
            for step, batch in enumerate(
                windowed_batches(
                    data_paths,
                    cont_cols,
                    disc_cols,
                    vocab,
                    mean,
                    std,
                    batch_size=int(config["batch_size"]),
                    seq_len=int(config["seq_len"]),
                    max_batches=int(config["max_batches"]),
                    return_file_id=False,
                    transforms=transforms,
                    shuffle_buffer=int(config.get("shuffle_buffer", 0)),
                )
            ):
                x_cont, _ = batch
                x_cont = x_cont.to(device)
                # Only the one-step-ahead prediction is needed for the loss.
                _, pred_next = temporal_model.forward_teacher(x_cont)
                temporal_loss = F.mse_loss(pred_next, x_cont[:, 1:, :])
                opt_temporal.zero_grad()
                temporal_loss.backward()
                if float(config.get("grad_clip", 0.0)) > 0:
                    torch.nn.utils.clip_grad_norm_(temporal_model.parameters(), float(config["grad_clip"]))
                opt_temporal.step()
                if step % int(config["log_every"]) == 0:
                    print("temporal_epoch", epoch, "step", step, "loss", float(temporal_loss))
        torch.save(temporal_model.state_dict(), os.path.join(out_dir, "temporal.pt"))

    # ---- Stage 2: hybrid diffusion training.
    # Loop-invariant tensors / scalars, hoisted out of the training loop.
    mask_tokens = torch.tensor(vocab_sizes, device=device)
    inv_std_weights = None
    if config.get("cont_loss_weighting") == "inv_std":
        # Per-channel 1/var weights computed from the raw (untransformed) std.
        inv_std_weights = torch.tensor(
            [1.0 / (float(raw_std[c]) ** 2 + float(config.get("cont_loss_eps", 1e-6))) for c in cont_cols],
            device=device,
        ).view(1, 1, -1)
    cont_target = str(config.get("cont_target", "eps"))
    q_weight = float(config.get("quantile_loss_weight", 0.0))
    q_tensor = None
    if q_weight > 0:
        q_points = config.get("quantile_points", [0.05, 0.25, 0.5, 0.75, 0.95])
        q_tensor = torch.tensor(q_points, device=device)
    stat_weight = float(config.get("residual_stat_weight", 0.0))
    lam = float(config["lambda"])

    total_step = 0
    for epoch in range(int(config["epochs"])):
        for step, batch in enumerate(
            windowed_batches(
                data_paths,
                cont_cols,
                disc_cols,
                vocab,
                mean,
                std,
                batch_size=int(config["batch_size"]),
                seq_len=int(config["seq_len"]),
                max_batches=int(config["max_batches"]),
                return_file_id=use_condition,
                transforms=transforms,
                shuffle_buffer=int(config.get("shuffle_buffer", 0)),
            )
        ):
            if use_condition:
                x_cont, x_disc, cond = batch
                cond = cond.to(device)
            else:
                x_cont, x_disc = batch
                cond = None
            x_cont = x_cont.to(device)
            x_disc = x_disc.to(device)

            # Subtract the frozen stage-1 trend; diffusion models the residual.
            trend = None
            if temporal_model is not None:
                temporal_model.eval()
                with torch.no_grad():
                    trend, _ = temporal_model.forward_teacher(x_cont)
            x_cont_resid = x_cont if trend is None else x_cont - trend

            # Sample a diffusion timestep per window and noise both modalities.
            bsz = x_cont.size(0)
            t = torch.randint(0, int(config["timesteps"]), (bsz,), device=device)
            x_cont_t, noise = q_sample_continuous(x_cont_resid, t, alphas_cumprod)
            x_disc_t, mask = q_sample_discrete(
                x_disc,
                t,
                mask_tokens,
                int(config["timesteps"]),
                mask_scale=float(config.get("disc_mask_scale", 1.0)),
            )
            eps_pred, logits = model(x_cont_t, x_disc_t, t, cond)

            # Continuous loss: predict either the noise (eps) or x0 directly.
            if cont_target == "x0":
                x0_target = x_cont_resid
                if float(config.get("cont_clamp_x0", 0.0)) > 0:
                    x0_target = torch.clamp(
                        x0_target, -float(config["cont_clamp_x0"]), float(config["cont_clamp_x0"])
                    )
                loss_base = (eps_pred - x0_target) ** 2
            else:
                loss_base = (eps_pred - noise) ** 2
            if inv_std_weights is not None:
                # Equalize per-channel contribution by down-weighting channels
                # with large raw variance.
                loss_base = loss_base * inv_std_weights.to(loss_base.dtype)
            if bool(config.get("snr_weighted_loss", False)):
                # Min-SNR-style weighting snr/(snr+gamma), applied per sample.
                # BUG FIX: the old code reduced the loss to a scalar first and
                # multiplied it by snr_weight.mean(), i.e. a single global
                # scale instead of per-timestep weighting (the trailing
                # .mean() was a no-op on a scalar).
                a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
                snr = a_bar_t / torch.clamp(1.0 - a_bar_t, min=1e-8)
                gamma = float(config.get("snr_gamma", 1.0))
                loss_base = loss_base * (snr / (snr + gamma))
            loss_cont = loss_base.mean()

            # Discrete loss: cross-entropy on masked positions only, averaged
            # over the channels that actually had masked tokens this step.
            loss_disc = 0.0
            loss_disc_count = 0
            for i, logit in enumerate(logits):
                ch_mask = mask[:, :, i]
                if ch_mask.any():
                    loss_disc = loss_disc + F.cross_entropy(logit[ch_mask], x_disc[:, :, i][ch_mask])
                    loss_disc_count += 1
            if loss_disc_count > 0:
                loss_disc = loss_disc / loss_disc_count

            loss = lam * loss_cont + (1 - lam) * loss_disc

            # Optional quantile-matching penalty on the residual distribution.
            if q_weight > 0:
                a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
                if cont_target == "x0":
                    x_gen = eps_pred
                else:
                    # Reconstruct x0 from the predicted noise.
                    x_gen = (x_cont_t - torch.sqrt(1.0 - a_bar_t) * eps_pred) / torch.sqrt(a_bar_t)
                x_real = x_cont_resid.view(-1, x_cont_resid.size(-1))
                x_gen = x_gen.view(-1, x_gen.size(-1))
                q_t = q_tensor.to(x_real.dtype)
                q_real = torch.quantile(x_real, q_t, dim=0)
                q_gen = torch.quantile(x_gen, q_t, dim=0)
                loss = loss + q_weight * torch.mean(torch.abs(q_gen - q_real))

            # Optional per-channel mean/std matching penalty on the residual.
            if stat_weight > 0:
                a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
                if cont_target == "x0":
                    x_gen = eps_pred
                else:
                    x_gen = (x_cont_t - torch.sqrt(1.0 - a_bar_t) * eps_pred) / torch.sqrt(a_bar_t)
                x_real = x_cont_resid
                stat_loss = F.mse_loss(x_gen.mean(dim=(0, 1)), x_real.mean(dim=(0, 1))) + F.mse_loss(
                    x_gen.std(dim=(0, 1)), x_real.std(dim=(0, 1))
                )
                loss = loss + stat_weight * stat_loss

            opt.zero_grad()
            loss.backward()
            if float(config.get("grad_clip", 0.0)) > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), float(config["grad_clip"]))
            opt.step()
            if ema is not None:
                ema.update(model)

            if step % int(config["log_every"]) == 0:
                print("epoch", epoch, "step", step, "loss", float(loss))
                with open(log_path, "a", encoding="utf-8") as f:
                    f.write(
                        "%d,%d,%.6f,%.6f,%.6f\n"
                        % (epoch, step, float(loss), float(loss_cont), float(loss_disc))
                    )

            total_step += 1
            if total_step % int(config["ckpt_every"]) == 0:
                ckpt = {
                    "model": model.state_dict(),
                    "optim": opt.state_dict(),
                    "config": config,
                    "step": total_step,
                }
                if ema is not None:
                    ckpt["ema"] = ema.state_dict()
                if temporal_model is not None:
                    ckpt["temporal"] = temporal_model.state_dict()
                torch.save(ckpt, os.path.join(out_dir, "model_ckpt.pt"))

    # Final artifacts: raw weights, EMA weights, and the temporal model.
    torch.save(model.state_dict(), os.path.join(out_dir, "model.pt"))
    if ema is not None:
        torch.save(ema.state_dict(), os.path.join(out_dir, "model_ema.pt"))
    if temporal_model is not None:
        torch.save(temporal_model.state_dict(), os.path.join(out_dir, "temporal.pt"))


if __name__ == "__main__":
    main()