Files
mask-ddpm/example/train.py

458 lines
18 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""Train hybrid diffusion on HAI (configurable runnable example)."""
import argparse
import json
import os
import random
from pathlib import Path
from typing import Dict
import torch
import torch.nn.functional as F
from data_utils import load_split, windowed_batches
from hybrid_diffusion import (
HybridDiffusionModel,
TemporalGRUGenerator,
cosine_beta_schedule,
q_sample_continuous,
q_sample_discrete,
)
from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path
# Anchor all default paths on this file's location inside the repository.
BASE_DIR = Path(__file__).resolve().parent
REPO_DIR = BASE_DIR.parent.parent
# Default training configuration. Every key can be overridden by the JSON
# file passed via --config; path-valued entries are normalized later by
# resolve_config_paths().
DEFAULTS = {
    # Input data / artifacts
    "data_path": REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz",
    "data_glob": REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz",
    "split_path": BASE_DIR / "feature_split.json",
    "stats_path": BASE_DIR / "results" / "cont_stats.json",
    "vocab_path": BASE_DIR / "results" / "disc_vocab.json",
    "out_dir": BASE_DIR / "results",
    # Core training hyper-parameters
    "device": "auto",
    "timesteps": 1000,
    "batch_size": 8,
    "seq_len": 64,
    "epochs": 1,
    "max_batches": 50,
    "lambda": 0.5,  # mixing weight: lam * continuous loss + (1 - lam) * discrete loss
    "lr": 1e-3,
    "seed": 1337,
    "log_every": 10,
    "ckpt_every": 50,
    # EMA / stabilization
    "ema_decay": 0.999,
    "use_ema": True,
    "clip_k": 5.0,
    "grad_clip": 1.0,
    # Conditioning
    "use_condition": True,
    "condition_type": "file_id",
    "cond_dim": 32,
    "use_tanh_eps": False,
    "eps_scale": 1.0,
    # Backbone architecture
    "model_time_dim": 128,
    "model_hidden_dim": 512,
    "model_num_layers": 2,
    "model_dropout": 0.1,
    "model_ff_mult": 2,
    "model_pos_dim": 64,
    "model_use_pos_embed": True,
    "disc_mask_scale": 0.9,
    "shuffle_buffer": 256,
    # Continuous-loss options
    "cont_loss_weighting": "none", # none | inv_std
    "cont_loss_eps": 1e-6,
    "cont_target": "eps", # eps | x0
    "cont_clamp_x0": 0.0,
    # Stage-1 temporal (trend) model
    "use_temporal_stage1": True,
    "temporal_hidden_dim": 256,
    "temporal_num_layers": 1,
    "temporal_dropout": 0.0,
    "temporal_epochs": 2,
    "temporal_lr": 1e-3,
    # Auxiliary distribution-matching losses
    "quantile_loss_weight": 0.0,
    "quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
    "snr_weighted_loss": True,
    "snr_gamma": 1.0,
    "residual_stat_weight": 0.0,
}
def load_json(path: str) -> Dict:
    """Read a UTF-8 JSON file at *path* and return the parsed object."""
    with open(path, "r", encoding="utf-8") as handle:
        return json.load(handle)
def set_seed(seed: int):
    """Seed every RNG this script uses and force deterministic cuDNN kernels."""
    # Deterministic cuDNN: disable autotuning, pick deterministic algorithms.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
# Device selection is delegated to resolve_device() from platform_utils.
def parse_args():
    """Define and evaluate the command-line interface for this script."""
    cli = argparse.ArgumentParser(description="Train hybrid diffusion on HAI.")
    cli.add_argument("--config", default=None, help="Path to JSON config.")
    cli.add_argument("--device", default="auto", help="cpu, cuda, or auto")
    return cli.parse_args()
def resolve_config_paths(config, base_dir: Path):
    """Normalize path-valued config entries to absolute-path strings.

    Relative entries are anchored at *base_dir*. Glob patterns are joined
    to *base_dir* verbatim (never resolved), and absolute paths pass
    through unchanged. Mutates and returns *config*.
    """
    for key in ("data_path", "data_glob", "split_path", "stats_path", "vocab_path", "out_dir"):
        if key not in config:
            continue
        value = config[key]
        if isinstance(value, str):
            # Glob patterns cannot survive Path.resolve() on Windows, so
            # they are only anchored at base_dir, never resolved.
            if any(ch in value for ch in "*?["):
                config[key] = str((base_dir / Path(value)))
                continue
            value = Path(value)
        if value.is_absolute():
            config[key] = str(value)
        else:
            config[key] = str(resolve_path(base_dir, value))
    return config
class EMA:
    """Exponential moving average over a model's trainable parameters."""

    def __init__(self, model, decay: float):
        self.decay = decay
        # Shadow copies start as detached clones of the current values.
        self.shadow = {
            name: param.detach().clone()
            for name, param in model.named_parameters()
            if param.requires_grad
        }

    def update(self, model):
        """Blend current parameters into the shadow: s <- d*s + (1-d)*p."""
        with torch.no_grad():
            for name, param in model.named_parameters():
                if param.requires_grad:
                    prev = self.shadow[name]
                    self.shadow[name] = prev * self.decay + param.detach() * (1.0 - self.decay)

    def state_dict(self):
        """Return the shadow-parameter mapping (the live dict, not a copy)."""
        return self.shadow
def _predict_x0(x_t, eps_pred, a_bar_t, cont_target: str):
    """Recover an x0 estimate from the network output.

    When the network predicts x0 directly (``cont_target == "x0"``) its
    output is returned unchanged; otherwise the standard DDPM inversion
    x0 = (x_t - sqrt(1 - a_bar) * eps) / sqrt(a_bar) is applied.
    """
    if cont_target == "x0":
        return eps_pred
    return (x_cont_t_inv := (x_t - torch.sqrt(1.0 - a_bar_t) * eps_pred) / torch.sqrt(a_bar_t))


def main():
    """Run stage-1 temporal pretraining (optional) and hybrid diffusion
    training on HAI, writing logs and checkpoints under ``out_dir``."""
    args = parse_args()
    if args.config:
        print("using_config", str(Path(args.config).resolve()))
    config = dict(DEFAULTS)
    if args.config:
        cfg_path = Path(args.config).resolve()
        config.update(load_json(str(cfg_path)))
        config = resolve_config_paths(config, cfg_path.parent)
    else:
        config = resolve_config_paths(config, BASE_DIR)
    # The command-line --device flag takes precedence over the config value.
    if args.device != "auto":
        config["device"] = args.device
    set_seed(int(config["seed"]))

    # Feature split: the time column and attack labels are never modelled.
    split = load_split(config["split_path"])
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
    type1_cols = config.get("type1_features", []) or []
    type5_cols = config.get("type5_features", []) or []
    type1_cols = [c for c in type1_cols if c in cont_cols]
    type5_cols = [c for c in type5_cols if c in cont_cols]
    # Columns the diffusion model actually generates: type1 features become
    # conditioning inputs, type5 features are excluded entirely.
    model_cont_cols = [c for c in cont_cols if c not in type1_cols and c not in type5_cols]
    if not model_cont_cols:
        raise SystemExit("model_cont_cols is empty; check type1/type5 config")

    stats = load_json(config["stats_path"])
    mean = stats["mean"]
    std = stats["std"]
    transforms = stats.get("transform", {})
    raw_std = stats.get("raw_std", std)
    quantile_probs = stats.get("quantile_probs")
    quantile_values = stats.get("quantile_values")
    use_quantile = bool(config.get("use_quantile_transform", False))
    vocab = load_json(config["vocab_path"])["vocab"]
    vocab_sizes = [len(vocab[c]) for c in disc_cols]

    # Prefer a glob of training files; fall back to the single data_path.
    data_paths = None
    if "data_glob" in config and config["data_glob"]:
        data_paths = sorted(Path(config["data_glob"]).parent.glob(Path(config["data_glob"]).name))
        if data_paths:
            data_paths = [safe_path(p) for p in data_paths]
    if not data_paths:
        data_paths = [safe_path(config["data_path"])]

    use_condition = bool(config.get("use_condition")) and config.get("condition_type") == "file_id"
    cond_vocab_size = len(data_paths) if use_condition else 0
    device = resolve_device(str(config["device"]))
    print("device", device)

    model = HybridDiffusionModel(
        cont_dim=len(model_cont_cols),
        disc_vocab_sizes=vocab_sizes,
        time_dim=int(config.get("model_time_dim", 64)),
        hidden_dim=int(config.get("model_hidden_dim", 256)),
        num_layers=int(config.get("model_num_layers", 1)),
        dropout=float(config.get("model_dropout", 0.0)),
        ff_mult=int(config.get("model_ff_mult", 2)),
        pos_dim=int(config.get("model_pos_dim", 64)),
        use_pos_embed=bool(config.get("model_use_pos_embed", True)),
        backbone_type=str(config.get("backbone_type", "gru")),
        transformer_num_layers=int(config.get("transformer_num_layers", 4)),
        transformer_nhead=int(config.get("transformer_nhead", 8)),
        transformer_ff_dim=int(config.get("transformer_ff_dim", 2048)),
        transformer_dropout=float(config.get("transformer_dropout", 0.1)),
        cond_cont_dim=len(type1_cols),
        cond_vocab_size=cond_vocab_size,
        cond_dim=int(config.get("cond_dim", 32)),
        use_tanh_eps=bool(config.get("use_tanh_eps", False)),
        eps_scale=float(config.get("eps_scale", 1.0)),
    ).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))

    temporal_model = None
    opt_temporal = None
    if bool(config.get("use_temporal_stage1", False)):
        temporal_model = TemporalGRUGenerator(
            input_dim=len(model_cont_cols),
            hidden_dim=int(config.get("temporal_hidden_dim", 256)),
            num_layers=int(config.get("temporal_num_layers", 1)),
            dropout=float(config.get("temporal_dropout", 0.0)),
        ).to(device)
        opt_temporal = torch.optim.Adam(
            temporal_model.parameters(),
            lr=float(config.get("temporal_lr", config["lr"])),
        )

    ema = EMA(model, float(config["ema_decay"])) if config.get("use_ema") else None
    betas = cosine_beta_schedule(int(config["timesteps"])).to(device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    os.makedirs(config["out_dir"], exist_ok=True)
    out_dir = safe_path(config["out_dir"])
    log_path = os.path.join(out_dir, "train_log.csv")
    with open(log_path, "w", encoding="utf-8") as f:
        f.write("epoch,step,loss,loss_cont,loss_disc\n")
    with open(os.path.join(out_dir, "config_used.json"), "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    # Column index lookups are loop-invariant; hoist them out of the batch
    # loops (the original recomputed them for every batch).
    model_idx = [cont_cols.index(c) for c in model_cont_cols]
    cond_idx = [cont_cols.index(c) for c in type1_cols] if type1_cols else []

    # ---- Stage 1: teacher-forced temporal (trend) model ----
    if temporal_model is not None and opt_temporal is not None:
        for epoch in range(int(config.get("temporal_epochs", 1))):
            for step, batch in enumerate(
                windowed_batches(
                    data_paths,
                    cont_cols,
                    disc_cols,
                    vocab,
                    mean,
                    std,
                    batch_size=int(config["batch_size"]),
                    seq_len=int(config["seq_len"]),
                    max_batches=int(config["max_batches"]),
                    return_file_id=False,
                    transforms=transforms,
                    quantile_probs=quantile_probs,
                    quantile_values=quantile_values,
                    use_quantile=use_quantile,
                    shuffle_buffer=int(config.get("shuffle_buffer", 0)),
                )
            ):
                x_cont, _ = batch
                x_cont = x_cont.to(device)
                x_cont_model = x_cont[:, :, model_idx]
                trend, pred_next = temporal_model.forward_teacher(x_cont_model)
                # One-step-ahead prediction loss.
                temporal_loss = F.mse_loss(pred_next, x_cont_model[:, 1:, :])
                opt_temporal.zero_grad()
                temporal_loss.backward()
                if float(config.get("grad_clip", 0.0)) > 0:
                    torch.nn.utils.clip_grad_norm_(temporal_model.parameters(), float(config["grad_clip"]))
                opt_temporal.step()
                if step % int(config["log_every"]) == 0:
                    print("temporal_epoch", epoch, "step", step, "loss", float(temporal_loss))
        torch.save(temporal_model.state_dict(), os.path.join(out_dir, "temporal.pt"))

    # ---- Stage 2: hybrid diffusion training ----
    total_step = 0
    for epoch in range(int(config["epochs"])):
        for step, batch in enumerate(
            windowed_batches(
                data_paths,
                cont_cols,
                disc_cols,
                vocab,
                mean,
                std,
                batch_size=int(config["batch_size"]),
                seq_len=int(config["seq_len"]),
                max_batches=int(config["max_batches"]),
                return_file_id=use_condition,
                transforms=transforms,
                quantile_probs=quantile_probs,
                quantile_values=quantile_values,
                use_quantile=use_quantile,
                shuffle_buffer=int(config.get("shuffle_buffer", 0)),
            )
        ):
            if use_condition:
                x_cont, x_disc, cond = batch
                cond = cond.to(device)
            else:
                x_cont, x_disc = batch
                cond = None
            x_cont = x_cont.to(device)
            x_disc = x_disc.to(device)
            x_cont_model = x_cont[:, :, model_idx]
            cond_cont = x_cont[:, :, cond_idx] if cond_idx else None
            trend = None
            if temporal_model is not None:
                temporal_model.eval()
                with torch.no_grad():
                    trend, _ = temporal_model.forward_teacher(x_cont_model)
            # Diffuse the residual around the stage-1 trend (or the raw
            # signal when stage 1 is disabled).
            x_cont_resid = x_cont_model if trend is None else x_cont_model - trend
            bsz = x_cont.size(0)
            t = torch.randint(0, int(config["timesteps"]), (bsz,), device=device)
            x_cont_t, noise = q_sample_continuous(x_cont_resid, t, alphas_cumprod)
            mask_tokens = torch.tensor(vocab_sizes, device=device)
            x_disc_t, mask = q_sample_discrete(
                x_disc,
                t,
                mask_tokens,
                int(config["timesteps"]),
                mask_scale=float(config.get("disc_mask_scale", 1.0)),
            )
            eps_pred, logits = model(x_cont_t, x_disc_t, t, cond, cond_cont=cond_cont)
            cont_target = str(config.get("cont_target", "eps"))
            if cont_target == "x0":
                x0_target = x_cont_resid
                if float(config.get("cont_clamp_x0", 0.0)) > 0:
                    x0_target = torch.clamp(x0_target, -float(config["cont_clamp_x0"]), float(config["cont_clamp_x0"]))
                loss_base = (eps_pred - x0_target) ** 2
            else:
                loss_base = (eps_pred - noise) ** 2
            if config.get("cont_loss_weighting") == "inv_std":
                # FIX: weights must cover exactly the modelled columns. The
                # original iterated cont_cols, which mismatches loss_base's
                # last dimension whenever type1/type5 features are configured.
                weights = torch.tensor(
                    [1.0 / (float(raw_std[c]) ** 2 + float(config.get("cont_loss_eps", 1e-6))) for c in model_cont_cols],
                    device=device,
                    dtype=eps_pred.dtype,
                ).view(1, 1, -1)
                loss_base = loss_base * weights
            if bool(config.get("snr_weighted_loss", False)):
                # FIX: min-SNR weighting is applied per sample BEFORE the
                # reduction; the original multiplied the scalar mean loss by
                # the mean weight, which discards the per-timestep effect.
                a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
                snr = a_bar_t / torch.clamp(1.0 - a_bar_t, min=1e-8)
                gamma = float(config.get("snr_gamma", 1.0))
                loss_base = loss_base * (snr / (snr + gamma))
            loss_cont = loss_base.mean()
            # Masked-token cross-entropy, averaged over the discrete features
            # that actually contain masked positions in this batch.
            loss_disc = 0.0
            loss_disc_count = 0
            for i, logit in enumerate(logits):
                if mask[:, :, i].any():
                    loss_disc = loss_disc + F.cross_entropy(
                        logit[mask[:, :, i]], x_disc[:, :, i][mask[:, :, i]]
                    )
                    loss_disc_count += 1
            if loss_disc_count > 0:
                loss_disc = loss_disc / loss_disc_count
            lam = float(config["lambda"])
            loss = lam * loss_cont + (1 - lam) * loss_disc
            q_weight = float(config.get("quantile_loss_weight", 0.0))
            if q_weight > 0:
                # Match per-feature quantiles of generated vs. real residuals.
                q_points = config.get("quantile_points", [0.05, 0.25, 0.5, 0.75, 0.95])
                q_tensor = torch.tensor(q_points, device=device, dtype=x_cont.dtype)
                a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
                x_gen = _predict_x0(x_cont_t, eps_pred, a_bar_t, cont_target)
                x_real = x_cont_resid.view(-1, x_cont_resid.size(-1))
                x_gen = x_gen.view(-1, x_gen.size(-1))
                q_real = torch.quantile(x_real, q_tensor, dim=0)
                q_gen = torch.quantile(x_gen, q_tensor, dim=0)
                quantile_loss = torch.mean(torch.abs(q_gen - q_real))
                loss = loss + q_weight * quantile_loss
            stat_weight = float(config.get("residual_stat_weight", 0.0))
            if stat_weight > 0:
                # Match first and second moments of the residual distribution.
                a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
                x_gen = _predict_x0(x_cont_t, eps_pred, a_bar_t, cont_target)
                x_real = x_cont_resid
                mean_real = x_real.mean(dim=(0, 1))
                mean_gen = x_gen.mean(dim=(0, 1))
                std_real = x_real.std(dim=(0, 1))
                std_gen = x_gen.std(dim=(0, 1))
                stat_loss = F.mse_loss(mean_gen, mean_real) + F.mse_loss(std_gen, std_real)
                loss = loss + stat_weight * stat_loss
            opt.zero_grad()
            loss.backward()
            if float(config.get("grad_clip", 0.0)) > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), float(config["grad_clip"]))
            opt.step()
            if ema is not None:
                ema.update(model)
            if step % int(config["log_every"]) == 0:
                print("epoch", epoch, "step", step, "loss", float(loss))
                with open(log_path, "a", encoding="utf-8") as f:
                    f.write(
                        "%d,%d,%.6f,%.6f,%.6f\n"
                        % (epoch, step, float(loss), float(loss_cont), float(loss_disc))
                    )
            total_step += 1
            if total_step % int(config["ckpt_every"]) == 0:
                ckpt = {
                    "model": model.state_dict(),
                    "optim": opt.state_dict(),
                    "config": config,
                    "step": total_step,
                }
                if ema is not None:
                    ckpt["ema"] = ema.state_dict()
                if temporal_model is not None:
                    ckpt["temporal"] = temporal_model.state_dict()
                torch.save(ckpt, os.path.join(out_dir, "model_ckpt.pt"))

    # Final artifacts.
    torch.save(model.state_dict(), os.path.join(out_dir, "model.pt"))
    if ema is not None:
        torch.save(ema.state_dict(), os.path.join(out_dir, "model_ema.pt"))
    if temporal_model is not None:
        torch.save(temporal_model.state_dict(), os.path.join(out_dir, "temporal.pt"))


if __name__ == "__main__":
    main()