mask-ddpm/example/train.py
#!/usr/bin/env python3
"""Train hybrid diffusion on HAI (configurable runnable example)."""
import argparse
import json
import os
import random
from pathlib import Path
from typing import Dict

import torch
import torch.nn.functional as F

from data_utils import load_split, windowed_batches
from hybrid_diffusion import (
    HybridDiffusionModel,
    TemporalGRUGenerator,
    cosine_beta_schedule,
    q_sample_continuous,
    q_sample_discrete,
)
from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path

BASE_DIR = Path(__file__).resolve().parent
REPO_DIR = BASE_DIR.parent.parent

DEFAULTS = {
    "data_path": REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz",
    "data_glob": REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz",
    "split_path": BASE_DIR / "feature_split.json",
    "stats_path": BASE_DIR / "results" / "cont_stats.json",
    "vocab_path": BASE_DIR / "results" / "disc_vocab.json",
    "out_dir": BASE_DIR / "results",
    "device": "auto",
    "timesteps": 1000,
    "batch_size": 8,
    "seq_len": 64,
    "epochs": 1,
    "max_batches": 50,
    "lambda": 0.5,
    "lr": 1e-3,
    "seed": 1337,
    "log_every": 10,
    "ckpt_every": 50,
    "ema_decay": 0.999,
    "use_ema": True,
    "clip_k": 5.0,
    "grad_clip": 1.0,
    "use_condition": True,
    "condition_type": "file_id",
    "cond_dim": 32,
    "use_tanh_eps": False,
    "eps_scale": 1.0,
    "model_time_dim": 128,
    "model_hidden_dim": 512,
    "model_num_layers": 2,
    "model_dropout": 0.1,
    "model_ff_mult": 2,
    "model_pos_dim": 64,
    "model_use_pos_embed": True,
    "model_use_feature_graph": True,
    "feature_graph_scale": 0.1,
    "feature_graph_dropout": 0.0,
    "use_temporal_stage1": True,
    "temporal_hidden_dim": 256,
    "temporal_num_layers": 1,
    "temporal_dropout": 0.0,
    "temporal_loss_weight": 1.0,
    "disc_mask_scale": 0.9,
    "shuffle_buffer": 256,
    "cont_loss_weighting": "none",  # none | inv_std
    "cont_loss_eps": 1e-6,
    "cont_target": "eps",  # eps | x0 | v
    "cont_clamp_x0": 0.0,
    "quantile_loss_weight": 0.0,
    "quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
    "quantile_loss_warmup_steps": 200,
    "quantile_loss_clip": 6.0,
    "quantile_loss_huber_delta": 1.0,
}


def load_json(path: str) -> Dict:
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def set_seed(seed: int):
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Device selection is handled by resolve_device from platform_utils.
def parse_args():
    parser = argparse.ArgumentParser(description="Train hybrid diffusion on HAI.")
    parser.add_argument("--config", default=None, help="Path to JSON config.")
    parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
    return parser.parse_args()


def resolve_config_paths(config, base_dir: Path):
    keys = ["data_path", "data_glob", "split_path", "stats_path", "vocab_path", "out_dir"]
    for key in keys:
        if key in config:
            # Convert string values to Path objects.
            if isinstance(config[key], str):
                path_str = config[key]
                # Glob patterns cannot be passed through Path.resolve() on
                # Windows, so join them against base_dir as plain strings.
                if "*" in path_str or "?" in path_str or "[" in path_str:
                    config[key] = str(base_dir / Path(path_str))
                    continue
                path = Path(path_str)
            else:
                path = config[key]
            if not path.is_absolute():
                config[key] = str(resolve_path(base_dir, path))
            else:
                config[key] = str(path)
    return config
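

# EMA (exponential moving average) keeps a "shadow" copy of the trainable
# weights, updated as shadow = decay * shadow + (1 - decay) * param after
# each optimizer step. Evaluating with the averaged weights is a standard
# trick for diffusion models and usually gives smoother samples than the
# raw weights.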
class EMA:
    def __init__(self, model, decay: float):
        self.decay = decay
        self.shadow = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.detach().clone()

    def update(self, model):
        with torch.no_grad():
            for name, param in model.named_parameters():
                if not param.requires_grad:
                    continue
                old = self.shadow[name]
                self.shadow[name] = old * self.decay + param.detach() * (1.0 - self.decay)

    def state_dict(self):
        return self.shadow


def main():
    args = parse_args()
    if args.config:
        print("using_config", str(Path(args.config).resolve()))
    config = dict(DEFAULTS)
    if args.config:
        cfg_path = Path(args.config).resolve()
        config.update(load_json(str(cfg_path)))
        config = resolve_config_paths(config, cfg_path.parent)
    else:
        config = resolve_config_paths(config, BASE_DIR)
    # A --device passed on the command line takes precedence over the config value.
    if args.device != "auto":
        config["device"] = args.device
    set_seed(int(config["seed"]))
    split = load_split(config["split_path"])
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
    stats = load_json(config["stats_path"])
    mean = stats["mean"]
    std = stats["std"]
    transforms = stats.get("transform", {})
    raw_std = stats.get("raw_std", std)
    vocab = load_json(config["vocab_path"])["vocab"]
    vocab_sizes = [len(vocab[c]) for c in disc_cols]
    data_paths = None
    if "data_glob" in config and config["data_glob"]:
        data_paths = sorted(Path(config["data_glob"]).parent.glob(Path(config["data_glob"]).name))
        if data_paths:
            data_paths = [safe_path(p) for p in data_paths]
    if not data_paths:
        data_paths = [safe_path(config["data_path"])]
    use_condition = bool(config.get("use_condition")) and config.get("condition_type") == "file_id"
    cond_vocab_size = len(data_paths) if use_condition else 0
    device = resolve_device(str(config["device"]))
    print("device", device)
    model = HybridDiffusionModel(
        cont_dim=len(cont_cols),
        disc_vocab_sizes=vocab_sizes,
        time_dim=int(config.get("model_time_dim", 64)),
        hidden_dim=int(config.get("model_hidden_dim", 256)),
        num_layers=int(config.get("model_num_layers", 1)),
        dropout=float(config.get("model_dropout", 0.0)),
        ff_mult=int(config.get("model_ff_mult", 2)),
        pos_dim=int(config.get("model_pos_dim", 64)),
        use_pos_embed=bool(config.get("model_use_pos_embed", True)),
        use_feature_graph=bool(config.get("model_use_feature_graph", False)),
        feature_graph_scale=float(config.get("feature_graph_scale", 0.1)),
        feature_graph_dropout=float(config.get("feature_graph_dropout", 0.0)),
        cond_vocab_size=cond_vocab_size,
        cond_dim=int(config.get("cond_dim", 32)),
        use_tanh_eps=bool(config.get("use_tanh_eps", False)),
        eps_scale=float(config.get("eps_scale", 1.0)),
    ).to(device)
    temporal_model = None
    if bool(config.get("use_temporal_stage1", False)):
        temporal_model = TemporalGRUGenerator(
            input_dim=len(cont_cols),
            hidden_dim=int(config.get("temporal_hidden_dim", 256)),
            num_layers=int(config.get("temporal_num_layers", 1)),
            dropout=float(config.get("temporal_dropout", 0.0)),
        ).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))
    if temporal_model is not None:
        opt_temporal = torch.optim.Adam(temporal_model.parameters(), lr=float(config["lr"]))
    else:
        opt_temporal = None
    ema = EMA(model, float(config["ema_decay"])) if config.get("use_ema") else None
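    # Cosine beta schedule (Nichol & Dhariwal): betas[t] is the per-step noise
    # variance and alphas_cumprod[t] is the cumulative product of (1 - beta),
    # i.e. the fraction of signal surviving after t noising steps.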
    betas = cosine_beta_schedule(int(config["timesteps"])).to(device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    os.makedirs(config["out_dir"], exist_ok=True)
    out_dir = safe_path(config["out_dir"])
    log_path = os.path.join(out_dir, "train_log.csv")
    use_quantile = float(config.get("quantile_loss_weight", 0.0)) > 0
    with open(log_path, "w", encoding="utf-8") as f:
        if use_quantile:
            f.write("epoch,step,loss,loss_cont,loss_disc,loss_quantile\n")
        else:
            f.write("epoch,step,loss,loss_cont,loss_disc\n")
    with open(os.path.join(out_dir, "config_used.json"), "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
    total_step = 0
    for epoch in range(int(config["epochs"])):
        for step, batch in enumerate(
            windowed_batches(
                data_paths,
                cont_cols,
                disc_cols,
                vocab,
                mean,
                std,
                batch_size=int(config["batch_size"]),
                seq_len=int(config["seq_len"]),
                max_batches=int(config["max_batches"]),
                return_file_id=use_condition,
                transforms=transforms,
                shuffle_buffer=int(config.get("shuffle_buffer", 0)),
            )
        ):
            if use_condition:
                x_cont, x_disc, cond = batch
                cond = cond.to(device)
            else:
                x_cont, x_disc = batch
                cond = None
            x_cont = x_cont.to(device)
            x_disc = x_disc.to(device)
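            # Optional stage 1: a GRU predicts the next step from the past
            # (teacher forcing); its output serves as a deterministic trend.
            # The diffusion model then learns only the detached residual
            # x_cont - trend, so the two stages train on independent graphs.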
            temporal_loss = None
            x_cont_resid = x_cont
            trend = None
            if temporal_model is not None:
                trend, pred_next = temporal_model.forward_teacher(x_cont)
                temporal_loss = F.mse_loss(pred_next, x_cont[:, 1:, :])
                temporal_loss = temporal_loss * float(config.get("temporal_loss_weight", 1.0))
                trend = trend.detach()
                x_cont_resid = x_cont - trend
            bsz = x_cont.size(0)
            t = torch.randint(0, int(config["timesteps"]), (bsz,), device=device)
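            # Forward diffusion. Continuous features follow the standard DDPM
            # corruption x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps.
            # Discrete features use absorbing-state masking: each column's
            # mask token id equals its vocab size, and q_sample_discrete
            # (presumably) replaces tokens with it at a rate growing with t.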
            x_cont_t, noise = q_sample_continuous(x_cont_resid, t, alphas_cumprod)
            mask_tokens = torch.tensor(vocab_sizes, device=device)
            x_disc_t, mask = q_sample_discrete(
                x_disc,
                t,
                mask_tokens,
                int(config["timesteps"]),
                mask_scale=float(config.get("disc_mask_scale", 1.0)),
            )
            eps_pred, logits = model(x_cont_t, x_disc_t, t, cond)
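            # The continuous head can be trained against three targets:
            #   eps: the injected noise (classic DDPM);
            #   x0:  the clean residual directly;
            #   v:   v = sqrt(abar_t) * eps - sqrt(1 - abar_t) * x_0
            #        (the v-prediction parameterization of Salimans & Ho).
            # eps_pred is reinterpreted accordingly.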
            cont_target = str(config.get("cont_target", "eps"))
            if cont_target == "x0":
                x0_target = x_cont_resid
                if float(config.get("cont_clamp_x0", 0.0)) > 0:
                    x0_target = torch.clamp(x0_target, -float(config["cont_clamp_x0"]), float(config["cont_clamp_x0"]))
                loss_base = (eps_pred - x0_target) ** 2
            elif cont_target == "v":
                a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
                v_target = torch.sqrt(a_bar_t) * noise - torch.sqrt(1.0 - a_bar_t) * x_cont_resid
                loss_base = (eps_pred - v_target) ** 2
            else:
                loss_base = (eps_pred - noise) ** 2
            if config.get("cont_loss_weighting") == "inv_std":
                weights = torch.tensor(
                    [1.0 / (float(raw_std[c]) ** 2 + float(config.get("cont_loss_eps", 1e-6))) for c in cont_cols],
                    device=device,
                    dtype=eps_pred.dtype,
                ).view(1, 1, -1)
                loss_cont = (loss_base * weights).mean()
            else:
                loss_cont = loss_base.mean()
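            # Discrete loss: cross-entropy on masked positions only, averaged
            # over the columns that actually received masks in this batch.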
            loss_disc = 0.0
            loss_disc_count = 0
            for i, logit in enumerate(logits):
                if mask[:, :, i].any():
                    loss_disc = loss_disc + F.cross_entropy(
                        logit[mask[:, :, i]], x_disc[:, :, i][mask[:, :, i]]
                    )
                    loss_disc_count += 1
            if loss_disc_count > 0:
                loss_disc = loss_disc / loss_disc_count
            lam = float(config["lambda"])
            loss = lam * loss_cont + (1 - lam) * loss_disc
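            # Optional quantile-matching regularizer: reconstruct x_0 from the
            # network output, then penalize the gap between per-feature
            # quantiles of the real and reconstructed residuals. For eps
            # prediction, x_0 = (x_t - sqrt(1 - abar_t) * eps_hat) / sqrt(abar_t).
            # Its weight warms up linearly over quantile_loss_warmup_steps.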
            q_weight = float(config.get("quantile_loss_weight", 0.0))
            quantile_loss = 0.0
            if q_weight > 0:
                warmup = int(config.get("quantile_loss_warmup_steps", 0))
                if warmup > 0:
                    q_weight = q_weight * min(1.0, (total_step + 1) / float(warmup))
                q_points = config.get("quantile_points", [0.05, 0.25, 0.5, 0.75, 0.95])
                q_tensor = torch.tensor(q_points, device=device, dtype=x_cont.dtype)
                # Use normalized space for stable quantiles on x0.
                x_real = x_cont_resid
                a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
                if cont_target == "x0":
                    x_gen = eps_pred
                elif cont_target == "v":
                    v_pred = eps_pred
                    x_gen = torch.sqrt(a_bar_t) * x_cont_t - torch.sqrt(1.0 - a_bar_t) * v_pred
                else:
                    # eps prediction
                    x_gen = (x_cont_t - torch.sqrt(1.0 - a_bar_t) * eps_pred) / torch.sqrt(a_bar_t)
                q_clip = float(config.get("quantile_loss_clip", 0.0))
                if q_clip > 0:
                    x_real = torch.clamp(x_real, -q_clip, q_clip)
                    x_gen = torch.clamp(x_gen, -q_clip, q_clip)
                x_real = x_real.view(-1, x_real.size(-1))
                x_gen = x_gen.view(-1, x_gen.size(-1))
                q_real = torch.quantile(x_real, q_tensor, dim=0)
                q_gen = torch.quantile(x_gen, q_tensor, dim=0)
                q_delta = float(config.get("quantile_loss_huber_delta", 0.0))
                q_diff = q_gen - q_real
                if q_delta > 0:
                    quantile_loss = torch.nn.functional.smooth_l1_loss(q_gen, q_real, beta=q_delta)
                else:
                    quantile_loss = torch.mean(torch.abs(q_diff))
                loss = loss + q_weight * quantile_loss
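            # The diffusion model and the stage-1 GRU are optimized separately.
            # Because trend was detached above, temporal_loss.backward() walks
            # a graph that is disjoint from loss.backward()'s.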
            opt.zero_grad()
            loss.backward()
            if float(config.get("grad_clip", 0.0)) > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), float(config["grad_clip"]))
            opt.step()
            if opt_temporal is not None:
                opt_temporal.zero_grad()
                temporal_loss.backward()
                if float(config.get("grad_clip", 0.0)) > 0:
                    torch.nn.utils.clip_grad_norm_(temporal_model.parameters(), float(config["grad_clip"]))
                opt_temporal.step()
            if ema is not None:
                ema.update(model)
            if step % int(config["log_every"]) == 0:
                print("epoch", epoch, "step", step, "loss", float(loss))
            with open(log_path, "a", encoding="utf-8") as f:
                if use_quantile:
                    f.write(
                        "%d,%d,%.6f,%.6f,%.6f,%.6f\n"
                        % (
                            epoch,
                            step,
                            float(loss),
                            float(loss_cont),
                            float(loss_disc),
                            float(quantile_loss),
                        )
                    )
                else:
                    f.write(
                        "%d,%d,%.6f,%.6f,%.6f\n"
                        % (epoch, step, float(loss), float(loss_cont), float(loss_disc))
                    )
            total_step += 1
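            # Checkpoints overwrite the same file, so model_ckpt.pt always
            # holds the most recent state (model, optimizer, config, step,
            # plus EMA/temporal weights when enabled).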
            if total_step % int(config["ckpt_every"]) == 0:
                ckpt = {
                    "model": model.state_dict(),
                    "optim": opt.state_dict(),
                    "config": config,
                    "step": total_step,
                }
                if ema is not None:
                    ckpt["ema"] = ema.state_dict()
                if temporal_model is not None:
                    ckpt["temporal"] = temporal_model.state_dict()
                torch.save(ckpt, os.path.join(out_dir, "model_ckpt.pt"))
    torch.save(model.state_dict(), os.path.join(out_dir, "model.pt"))
    if ema is not None:
        torch.save(ema.state_dict(), os.path.join(out_dir, "model_ema.pt"))
    if temporal_model is not None:
        torch.save(temporal_model.state_dict(), os.path.join(out_dir, "temporal.pt"))


if __name__ == "__main__":
    main()
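
# Example invocation (paths are illustrative):
#   python train.py --config configs/hai_example.json --device cuda
# Without --config, the DEFAULTS above are used as-is.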