Files
mask-ddpm/example/train.py
2026-01-26 18:27:41 +08:00

387 lines
14 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""Train hybrid diffusion on HAI (configurable runnable example)."""
import argparse
import json
import os
import random
from pathlib import Path
from typing import Dict
import torch
import torch.nn.functional as F
from data_utils import load_split, windowed_batches
from hybrid_diffusion import (
HybridDiffusionModel,
cosine_beta_schedule,
q_sample_continuous,
q_sample_discrete,
)
from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path
# Directory that contains this script; most relative defaults are anchored here.
BASE_DIR = Path(__file__).resolve().parent
# Two directory levels above BASE_DIR; the default dataset location is built from it.
REPO_DIR = BASE_DIR.parent.parent
# Baseline training configuration. main() copies this dict, merges the optional
# --config JSON over it (JSON keys win), and then normalises the path-valued
# entries to strings with resolve_config_paths(). The merged dict is written to
# out_dir/config_used.json in insertion order, so key order here is visible there.
DEFAULTS = {
    # --- data and artefact locations ---
    # Single training file; used only when data_glob is empty or matches nothing.
    "data_path": REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz",
    # Glob pattern; when it matches files they are used (sorted) instead of data_path.
    "data_glob": REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz",
    # JSON with the continuous/discrete column lists (read via load_split).
    "split_path": BASE_DIR / "feature_split.json",
    # JSON with per-column "mean"/"std" (and optional "transform", "raw_std").
    "stats_path": BASE_DIR / "results" / "cont_stats.json",
    # JSON whose "vocab" maps each discrete column to its vocabulary.
    "vocab_path": BASE_DIR / "results" / "disc_vocab.json",
    # Receives train_log.csv, config_used.json and the saved model files.
    "out_dir": BASE_DIR / "results",
    # Passed to resolve_device(); a non-"auto" --device CLI flag overrides it.
    "device": "auto",
    # --- diffusion / optimisation ---
    # Number of diffusion steps for the cosine beta schedule and timestep sampling.
    "timesteps": 1000,
    # batch_size, seq_len and max_batches are forwarded to windowed_batches each epoch.
    "batch_size": 8,
    "seq_len": 64,
    "epochs": 1,
    "max_batches": 50,
    # Total loss = lambda * continuous loss + (1 - lambda) * discrete loss.
    "lambda": 0.5,
    "lr": 1e-3,
    "seed": 1337,
    # log_every counts per-epoch steps; ckpt_every counts global steps.
    "log_every": 10,
    "ckpt_every": 50,
    "ema_decay": 0.999,
    "use_ema": True,
    # NOTE(review): clip_k is not read anywhere in this script; possibly consumed
    # by another tool via config_used.json or the checkpoint config; confirm.
    "clip_k": 5.0,
    # Max gradient norm for clip_grad_norm_; a value <= 0 disables clipping.
    "grad_clip": 1.0,
    # --- conditioning ---
    # Conditioning is only enabled when condition_type == "file_id"; the condition
    # vocabulary size is then the number of training files.
    "use_condition": True,
    "condition_type": "file_id",
    "cond_dim": 32,
    # --- model architecture (forwarded to HybridDiffusionModel) ---
    "use_tanh_eps": False,
    "eps_scale": 1.0,
    "model_time_dim": 128,
    "model_hidden_dim": 512,
    "model_num_layers": 2,
    "model_dropout": 0.1,
    "model_ff_mult": 2,
    "model_pos_dim": 64,
    "model_use_pos_embed": True,
    "model_use_feature_graph": True,
    "feature_graph_scale": 0.1,
    "feature_graph_dropout": 0.0,
    # Forwarded to q_sample_discrete as mask_scale.
    "disc_mask_scale": 0.9,
    # Forwarded to windowed_batches.
    "shuffle_buffer": 256,
    # --- continuous loss ---
    # "inv_std" weights each feature's squared error by 1 / (raw_std**2 + cont_loss_eps).
    "cont_loss_weighting": "none", # none | inv_std
    "cont_loss_eps": 1e-6,
    # What the continuous head output is regressed against.
    "cont_target": "eps", # eps | x0 | v
    # When > 0, the x0 regression target is clamped to [-cont_clamp_x0, cont_clamp_x0].
    "cont_clamp_x0": 0.0,
    # --- auxiliary quantile-matching loss (disabled while the weight is 0) ---
    "quantile_loss_weight": 0.0,
    "quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
    # The weight ramps up linearly over this many global steps.
    "quantile_loss_warmup_steps": 200,
    # When > 0, real and predicted x0 are clamped to +/- this value before taking quantiles.
    "quantile_loss_clip": 6.0,
    # When > 0, smooth-L1 with this beta is used; otherwise plain mean absolute difference.
    "quantile_loss_huber_delta": 1.0,
}
def load_json(path: str) -> Dict:
    """Parse the UTF-8 JSON file at ``path`` and return the decoded object."""
    with open(path, encoding="utf-8") as handle:
        payload = handle.read()
    return json.loads(payload)
def set_seed(seed: int):
    """Seed Python's ``random`` and PyTorch (CPU and, if present, all CUDA devices).

    Also switches cuDNN into deterministic, non-benchmarking mode so repeated
    runs with the same seed pick the same kernels.
    """
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
# Device selection uses resolve_device() imported from platform_utils.
def parse_args():
    """Parse the script's command-line flags.

    Returns:
        argparse.Namespace with ``config`` (JSON path or None) and ``device``
        (defaults to "auto").
    """
    cli = argparse.ArgumentParser(description="Train hybrid diffusion on HAI.")
    flag_specs = (
        ("--config", None, "Path to JSON config."),
        ("--device", "auto", "cpu, cuda, or auto"),
    )
    for flag, default_value, help_text in flag_specs:
        cli.add_argument(flag, default=default_value, help=help_text)
    return cli.parse_args()
def resolve_config_paths(config, base_dir: Path):
    """Normalise the path-valued entries of ``config`` to strings, in place.

    For each known path key:

    * Missing keys, ``None`` and ``""`` are left untouched. This lets a JSON
      config switch an entry off (e.g. ``"data_glob": null``) instead of
      crashing or having the empty string silently become ``base_dir``.
    * Values containing glob wildcards (``*``, ``?``, ``[``) are only joined
      onto ``base_dir``, because ``Path.resolve()`` cannot handle them on
      Windows. This applies to ``Path`` values as well as strings.
    * Other relative paths are resolved against ``base_dir`` via
      ``resolve_path``; absolute paths are kept as they are.

    Args:
        config: configuration dict; modified in place.
        base_dir: directory that relative paths are interpreted against.

    Returns:
        The same ``config`` object, for chaining.
    """
    keys = ("data_path", "data_glob", "split_path", "stats_path", "vocab_path", "out_dir")
    for key in keys:
        value = config.get(key)
        # Disabled or absent entry: nothing to normalise.
        if value is None or (isinstance(value, str) and not value):
            continue
        path_str = str(value)
        if any(ch in path_str for ch in "*?["):
            # Joining onto base_dir leaves an absolute pattern unchanged (pathlib semantics).
            config[key] = str(base_dir / Path(path_str))
            continue
        # Accept both strings and Path objects.
        path = Path(value)
        if path.is_absolute():
            config[key] = str(path)
        else:
            config[key] = str(resolve_path(base_dir, path))
    return config
class EMA:
    """Exponential moving average of a model's trainable parameters.

    Each ``update`` applies ``shadow = decay * shadow + (1 - decay) * param``
    to every tracked tensor. Only parameters with ``requires_grad=True`` at
    construction time are tracked. Buffers are not averaged.
    """

    def __init__(self, model, decay: float):
        self.decay = decay
        # Detached clones, so the shadow never aliases the live weights or joins the autograd graph.
        self.shadow = {
            name: param.detach().clone()
            for name, param in model.named_parameters()
            if param.requires_grad
        }

    def update(self, model):
        """Blend the model's current trainable parameters into the shadow, in place."""
        keep = self.decay
        blend = 1.0 - self.decay
        with torch.no_grad():
            for name, param in model.named_parameters():
                if not param.requires_grad:
                    continue
                # In-place update avoids allocating a new full-size tensor per parameter on every step.
                self.shadow[name].mul_(keep).add_(param.detach(), alpha=blend)

    def state_dict(self):
        """Return the live ``name -> averaged tensor`` mapping (not a copy).

        The tensors are updated in place by ``update``. Serialise or clone
        them if a snapshot is needed.
        """
        return self.shadow
def _build_model(config, cont_dim, vocab_sizes, cond_vocab_size, device):
    """Instantiate HybridDiffusionModel from the ``model_*`` config keys and move it to ``device``."""
    return HybridDiffusionModel(
        cont_dim=cont_dim,
        disc_vocab_sizes=vocab_sizes,
        time_dim=int(config.get("model_time_dim", 64)),
        hidden_dim=int(config.get("model_hidden_dim", 256)),
        num_layers=int(config.get("model_num_layers", 1)),
        dropout=float(config.get("model_dropout", 0.0)),
        ff_mult=int(config.get("model_ff_mult", 2)),
        pos_dim=int(config.get("model_pos_dim", 64)),
        use_pos_embed=bool(config.get("model_use_pos_embed", True)),
        use_feature_graph=bool(config.get("model_use_feature_graph", False)),
        feature_graph_scale=float(config.get("feature_graph_scale", 0.1)),
        feature_graph_dropout=float(config.get("feature_graph_dropout", 0.0)),
        cond_vocab_size=cond_vocab_size,
        cond_dim=int(config.get("cond_dim", 32)),
        use_tanh_eps=bool(config.get("use_tanh_eps", False)),
        eps_scale=float(config.get("eps_scale", 1.0)),
    ).to(device)


def _discrete_loss(logits, x_disc, mask):
    """Average cross-entropy over the masked positions of each discrete column.

    Columns with no masked position in this batch are skipped. Returns the
    float ``0.0`` when nothing at all was masked.
    """
    total = 0.0
    count = 0
    for i, logit in enumerate(logits):
        col_mask = mask[:, :, i]
        if col_mask.any():
            total = total + F.cross_entropy(logit[col_mask], x_disc[:, :, i][col_mask])
            count += 1
    return total / count if count > 0 else total


def _predicted_x0(cont_target, pred, x_t, a_bar_t):
    """Convert the continuous head's raw output into an x0 estimate.

    ``cont_target`` must be "x0", "v" or "eps". In eps mode this divides by
    sqrt(a_bar_t), which is tiny near the last timesteps. The quantile clip
    in ``_quantile_loss`` keeps those estimates bounded.
    """
    if cont_target == "x0":
        return pred
    if cont_target == "v":
        return torch.sqrt(a_bar_t) * x_t - torch.sqrt(1.0 - a_bar_t) * pred
    return (x_t - torch.sqrt(1.0 - a_bar_t) * pred) / torch.sqrt(a_bar_t)


def _quantile_loss(x_real, x_gen, q_tensor, q_clip, q_delta):
    """Distance between per-feature quantiles of real and predicted x0.

    Positions across batch and sequence are pooled per feature (last axis).
    The loss is smooth-L1 with ``beta=q_delta`` when ``q_delta > 0``, and
    mean absolute difference otherwise.
    """
    if q_clip > 0:
        x_real = torch.clamp(x_real, -q_clip, q_clip)
        x_gen = torch.clamp(x_gen, -q_clip, q_clip)
    # reshape (not view) also handles non-contiguous inputs.
    x_real = x_real.reshape(-1, x_real.size(-1))
    x_gen = x_gen.reshape(-1, x_gen.size(-1))
    q_real = torch.quantile(x_real, q_tensor, dim=0)
    q_gen = torch.quantile(x_gen, q_tensor, dim=0)
    if q_delta > 0:
        return F.smooth_l1_loss(q_gen, q_real, beta=q_delta)
    return torch.mean(torch.abs(q_gen - q_real))


def _append_log_row(log_path, epoch, step, losses):
    """Append one CSV row: integer epoch and step, then each loss with 6 decimals."""
    fields = ["%d" % epoch, "%d" % step] + ["%.6f" % float(v) for v in losses]
    with open(log_path, "a", encoding="utf-8") as f:
        f.write(",".join(fields) + "\n")


def _save_checkpoint(path, model, opt, ema, config, step):
    """Write a model/optimizer/config checkpoint (plus EMA weights if enabled) to ``path``."""
    ckpt = {
        "model": model.state_dict(),
        "optim": opt.state_dict(),
        "config": config,
        "step": step,
    }
    if ema is not None:
        ckpt["ema"] = ema.state_dict()
    torch.save(ckpt, path)


def main():
    """Train the hybrid diffusion model and write its artefacts to ``out_dir``.

    Configuration is ``DEFAULTS`` overlaid with the optional ``--config``
    JSON. ``--device`` overrides the configured device unless it is "auto".

    Files written to ``out_dir``:
        * ``train_log.csv``
        * ``config_used.json``
        * ``model_ckpt.pt``, every ``ckpt_every`` global steps and once at the
          end if the run stopped between checkpoints
        * ``model.pt``
        * ``model_ema.pt``, when EMA is enabled

    A ``log_every`` or ``ckpt_every`` value <= 0 disables that output.

    Raises:
        ValueError: if ``cont_target`` is not "eps", "x0" or "v".
    """
    args = parse_args()
    config = dict(DEFAULTS)
    if args.config:
        cfg_path = Path(args.config).resolve()
        print("using_config", str(cfg_path))
        config.update(load_json(str(cfg_path)))
        config = resolve_config_paths(config, cfg_path.parent)
    else:
        config = resolve_config_paths(config, BASE_DIR)
    # An explicit --device on the command line takes priority over the config file.
    if args.device != "auto":
        config["device"] = args.device

    cont_target = str(config.get("cont_target", "eps"))
    if cont_target not in ("eps", "x0", "v"):
        # Previously an unknown value silently fell back to eps-prediction.
        raise ValueError("cont_target must be one of 'eps', 'x0', 'v'; got %r" % cont_target)

    set_seed(int(config["seed"]))

    split = load_split(config["split_path"])
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]

    stats = load_json(config["stats_path"])
    mean = stats["mean"]
    std = stats["std"]
    transforms = stats.get("transform", {})
    raw_std = stats.get("raw_std", std)
    vocab = load_json(config["vocab_path"])["vocab"]
    vocab_sizes = [len(vocab[c]) for c in disc_cols]

    # Prefer every file matched by data_glob; otherwise fall back to the single data_path.
    data_paths = None
    if config.get("data_glob"):
        glob_path = Path(config["data_glob"])
        data_paths = [safe_path(p) for p in sorted(glob_path.parent.glob(glob_path.name))]
    if not data_paths:
        data_paths = [safe_path(config["data_path"])]

    use_condition = bool(config.get("use_condition")) and config.get("condition_type") == "file_id"
    cond_vocab_size = len(data_paths) if use_condition else 0

    device = resolve_device(str(config["device"]))
    print("device", device)
    model = _build_model(config, len(cont_cols), vocab_sizes, cond_vocab_size, device)
    opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))
    ema = EMA(model, float(config["ema_decay"])) if config.get("use_ema") else None

    timesteps = int(config["timesteps"])
    betas = cosine_beta_schedule(timesteps).to(device)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    # Loop-invariant settings, parsed once instead of on every step.
    lam = float(config["lambda"])
    disc_mask_scale = float(config.get("disc_mask_scale", 1.0))
    clamp_x0 = float(config.get("cont_clamp_x0", 0.0))
    grad_clip = float(config.get("grad_clip", 0.0))
    log_every = int(config["log_every"])
    ckpt_every = int(config["ckpt_every"])
    q_weight_base = float(config.get("quantile_loss_weight", 0.0))
    use_quantile = q_weight_base > 0
    q_warmup = int(config.get("quantile_loss_warmup_steps", 0))
    q_points = config.get("quantile_points", [0.05, 0.25, 0.5, 0.75, 0.95])
    q_clip = float(config.get("quantile_loss_clip", 0.0))
    q_delta = float(config.get("quantile_loss_huber_delta", 0.0))

    # Index of the [MASK] token for each discrete column.
    mask_tokens = torch.tensor(vocab_sizes, device=device)
    inv_std_values = None
    if config.get("cont_loss_weighting") == "inv_std":
        cont_loss_eps = float(config.get("cont_loss_eps", 1e-6))
        inv_std_values = [1.0 / (float(raw_std[c]) ** 2 + cont_loss_eps) for c in cont_cols]
    # Built lazily on the first step so their dtypes match the model output and the data.
    cont_weights = None
    q_tensor = None

    os.makedirs(config["out_dir"], exist_ok=True)
    out_dir = safe_path(config["out_dir"])
    log_path = os.path.join(out_dir, "train_log.csv")
    ckpt_path = os.path.join(out_dir, "model_ckpt.pt")
    header = "epoch,step,loss,loss_cont,loss_disc"
    if use_quantile:
        header += ",loss_quantile"
    with open(log_path, "w", encoding="utf-8") as f:
        f.write(header + "\n")
    with open(os.path.join(out_dir, "config_used.json"), "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    total_step = 0
    for epoch in range(int(config["epochs"])):
        batches = windowed_batches(
            data_paths,
            cont_cols,
            disc_cols,
            vocab,
            mean,
            std,
            batch_size=int(config["batch_size"]),
            seq_len=int(config["seq_len"]),
            max_batches=int(config["max_batches"]),
            return_file_id=use_condition,
            transforms=transforms,
            shuffle_buffer=int(config.get("shuffle_buffer", 0)),
        )
        for step, batch in enumerate(batches):
            if use_condition:
                x_cont, x_disc, cond = batch
                cond = cond.to(device)
            else:
                x_cont, x_disc = batch
                cond = None
            x_cont = x_cont.to(device)
            x_disc = x_disc.to(device)

            # Forward diffusion: one random timestep per sample, noise continuous and mask discrete.
            t = torch.randint(0, timesteps, (x_cont.size(0),), device=device)
            x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
            x_disc_t, mask = q_sample_discrete(
                x_disc,
                t,
                mask_tokens,
                timesteps,
                mask_scale=disc_mask_scale,
            )
            eps_pred, logits = model(x_cont_t, x_disc_t, t, cond)

            # Continuous loss: MSE against the configured target parameterisation.
            a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
            if cont_target == "x0":
                target = x_cont
                if clamp_x0 > 0:
                    target = torch.clamp(target, -clamp_x0, clamp_x0)
            elif cont_target == "v":
                target = torch.sqrt(a_bar_t) * noise - torch.sqrt(1.0 - a_bar_t) * x_cont
            else:
                target = noise
            loss_base = (eps_pred - target) ** 2
            if inv_std_values is not None:
                if cont_weights is None:
                    cont_weights = torch.tensor(
                        inv_std_values, device=device, dtype=eps_pred.dtype
                    ).view(1, 1, -1)
                loss_cont = (loss_base * cont_weights).mean()
            else:
                loss_cont = loss_base.mean()

            loss_disc = _discrete_loss(logits, x_disc, mask)
            loss = lam * loss_cont + (1 - lam) * loss_disc

            quantile_loss = 0.0
            if use_quantile:
                q_weight = q_weight_base
                if q_warmup > 0:
                    # Linear warm-up of the quantile term over the first q_warmup global steps.
                    q_weight = q_weight * min(1.0, (total_step + 1) / float(q_warmup))
                if q_tensor is None:
                    q_tensor = torch.tensor(q_points, device=device, dtype=x_cont.dtype)
                # Quantiles are matched in normalised space on x0 for stability.
                x_gen = _predicted_x0(cont_target, eps_pred, x_cont_t, a_bar_t)
                quantile_loss = _quantile_loss(x_cont, x_gen, q_tensor, q_clip, q_delta)
                loss = loss + q_weight * quantile_loss

            opt.zero_grad()
            loss.backward()
            if grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            opt.step()
            if ema is not None:
                ema.update(model)

            # log_every counts per-epoch steps; a value <= 0 disables logging.
            if log_every > 0 and step % log_every == 0:
                print("epoch", epoch, "step", step, "loss", float(loss))
                losses = [loss, loss_cont, loss_disc]
                if use_quantile:
                    losses.append(quantile_loss)
                _append_log_row(log_path, epoch, step, losses)

            total_step += 1
            # ckpt_every counts global steps; a value <= 0 disables periodic checkpoints.
            if ckpt_every > 0 and total_step % ckpt_every == 0:
                _save_checkpoint(ckpt_path, model, opt, ema, config, total_step)

    # Keep model_ckpt.pt in sync with model.pt when the run did not end on a checkpoint boundary.
    if ckpt_every > 0 and total_step % ckpt_every != 0:
        _save_checkpoint(ckpt_path, model, opt, ema, config, total_step)
    torch.save(model.state_dict(), os.path.join(out_dir, "model.pt"))
    if ema is not None:
        torch.save(ema.state_dict(), os.path.join(out_dir, "model_ema.pt"))
if __name__ == "__main__":
    # Start training only when this file is executed as a script, not when imported.
    main()