381 lines
14 KiB
Python
Executable File
381 lines
14 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
"""Train hybrid diffusion on HAI (configurable runnable example)."""
|
||
|
||
import argparse
|
||
import json
|
||
import os
|
||
import random
|
||
from pathlib import Path
|
||
from typing import Dict
|
||
|
||
import torch
|
||
import torch.nn.functional as F
|
||
|
||
from data_utils import load_split, windowed_batches
|
||
from hybrid_diffusion import (
|
||
HybridDiffusionModel,
|
||
cosine_beta_schedule,
|
||
q_sample_continuous,
|
||
q_sample_discrete,
|
||
)
|
||
from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path
|
||
|
||
|
||
# Directory holding this script; the default split/stats/vocab files live here.
BASE_DIR = Path(__file__).resolve().parent
# Two directory levels above BASE_DIR; the default dataset tree hangs off it.
REPO_DIR = BASE_DIR.parent.parent

# Built-in training configuration. ``main`` copies this dict and overlays the
# keys of an optional ``--config`` JSON file on top of it. Key order is kept
# stable because the merged config is dumped to ``config_used.json``.
DEFAULTS = {
    # Input data and artifact locations.
    "data_path": REPO_DIR.joinpath("dataset", "hai", "hai-21.03", "train1.csv.gz"),
    "data_glob": REPO_DIR.joinpath("dataset", "hai", "hai-21.03", "train*.csv.gz"),
    "split_path": BASE_DIR.joinpath("feature_split.json"),
    "stats_path": BASE_DIR.joinpath("results", "cont_stats.json"),
    "vocab_path": BASE_DIR.joinpath("results", "disc_vocab.json"),
    "out_dir": BASE_DIR.joinpath("results"),
    # Runtime, diffusion schedule and data windowing.
    "device": "auto",
    "timesteps": 1000,
    "batch_size": 8,
    "seq_len": 64,
    "epochs": 1,
    "max_batches": 50,
    # Loss mix: lambda * continuous + (1 - lambda) * discrete.
    "lambda": 0.5,
    # Optimisation, logging and checkpoint cadence.
    "lr": 0.001,
    "seed": 1337,
    "log_every": 10,
    "ckpt_every": 50,
    "ema_decay": 0.999,
    "use_ema": True,
    "clip_k": 5.0,  # not referenced elsewhere in this file
    "grad_clip": 1.0,
    # File-id conditioning and HybridDiffusionModel architecture.
    "use_condition": True,
    "condition_type": "file_id",
    "cond_dim": 32,
    "use_tanh_eps": False,
    "eps_scale": 1.0,
    "model_time_dim": 128,
    "model_hidden_dim": 512,
    "model_num_layers": 2,
    "model_dropout": 0.1,
    "model_ff_mult": 2,
    "model_pos_dim": 64,
    "model_use_pos_embed": True,
    # Discrete masking strength and streaming shuffle buffer.
    "disc_mask_scale": 0.9,
    "shuffle_buffer": 256,
    # Continuous-loss options.
    "cont_loss_weighting": "none",  # none | inv_std
    "cont_loss_eps": 1e-6,
    "cont_target": "eps",  # eps | x0 | v
    "cont_clamp_x0": 0.0,
    # Quantile-matching auxiliary loss (active only when its weight is > 0).
    "quantile_loss_weight": 0.0,
    "quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
    "quantile_loss_warmup_steps": 200,
    "quantile_loss_clip": 6.0,
    "quantile_loss_huber_delta": 1.0,
}
|
||
|
||
|
||
def load_json(path: str) -> Dict:
    """Read a UTF-8 encoded JSON file and return the decoded value.

    Args:
        path: Filesystem path of the JSON document.

    Returns:
        The decoded JSON value. This script only reads object-valued files
        (config, stats and vocab), hence the ``Dict`` annotation.
    """
    with open(path, encoding="utf-8") as handle:
        payload = json.load(handle)
    return payload
|
||
|
||
|
||
def set_seed(seed: int):
    """Seed Python's and PyTorch's random number generators.

    CUDA generators are seeded too when CUDA is available. cuDNN autotuning is
    turned off and deterministic cuDNN kernels are requested, trading speed
    for run-to-run reproducibility.

    Args:
        seed: Seed value applied to every generator.
    """
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
|
||
|
||
|
||
# Device selection is delegated to resolve_device from platform_utils.
|
||
|
||
|
||
def parse_args(argv=None):
    """Parse the training script's command-line options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing a list
            allows programmatic invocation and testing.

    Returns:
        ``argparse.Namespace`` with ``config`` (path string, or ``None`` when
        omitted) and ``device`` (string, ``"auto"`` when omitted).
    """
    parser = argparse.ArgumentParser(description="Train hybrid diffusion on HAI.")
    parser.add_argument("--config", default=None, help="Path to JSON config.")
    parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
    return parser.parse_args(argv)
|
||
|
||
|
||
def resolve_config_paths(config, base_dir: Path):
    """Normalise the path-valued config entries to strings, in place.

    For each known path key:
      * missing, ``None`` or empty-string values are left untouched, so a
        config can disable an optional path such as ``"data_glob": null``.
        Previously ``None`` crashed on ``.is_absolute()``, and ``""`` was
        treated as the relative path ``"."``.
      * string glob patterns (containing ``*``, ``?`` or ``[``) are joined
        onto ``base_dir`` without resolving them.
      * other relative paths are resolved against ``base_dir`` via
        ``resolve_path``; absolute paths are kept as they are.

    Args:
        config: Configuration dict; modified in place.
        base_dir: Directory that relative paths are interpreted against.

    Returns:
        The same ``config`` object.
    """
    keys = ["data_path", "data_glob", "split_path", "stats_path", "vocab_path", "out_dir"]
    for key in keys:
        value = config.get(key)
        if value is None or value == "":
            continue
        if isinstance(value, str):
            # A glob pattern cannot be Path.resolve()'d on Windows, so only join it.
            if any(ch in value for ch in "*?["):
                config[key] = str(base_dir / Path(value))
                continue
            path = Path(value)
        else:
            path = value
        if path.is_absolute():
            config[key] = str(path)
        else:
            config[key] = str(resolve_path(base_dir, path))
    return config
|
||
|
||
|
||
class EMA:
    """Exponential moving average of a model's trainable parameters.

    The shadow copy is a dict keyed by parameter name. It holds only
    parameters with ``requires_grad=True``, and each ``update`` blends them as
    ``shadow = decay * shadow + (1 - decay) * param``.
    """

    def __init__(self, model, decay: float):
        self.decay = decay
        # Detached clones, so later optimiser steps do not alias the shadow copy.
        self.shadow = {
            name: param.detach().clone()
            for name, param in model.named_parameters()
            if param.requires_grad
        }

    def update(self, model):
        """Blend the model's current trainable parameters into the shadow copy."""
        keep = self.decay
        blend = 1.0 - keep
        with torch.no_grad():
            trainable = ((n, p) for n, p in model.named_parameters() if p.requires_grad)
            for name, param in trainable:
                self.shadow[name] = self.shadow[name] * keep + param.detach() * blend

    def state_dict(self):
        """Return the shadow dict itself (not a copy), keyed by parameter name."""
        return self.shadow
|
||
|
||
|
||
def _resolve_data_paths(config):
    """Return the training files: sorted ``data_glob`` matches, else ``[data_path]``."""
    data_paths = []
    pattern = config.get("data_glob")
    if pattern:
        pattern = Path(pattern)
        # Sorted so the file order (and hence file-id conditioning) is stable across runs.
        data_paths = [safe_path(p) for p in sorted(pattern.parent.glob(pattern.name))]
    if not data_paths:
        data_paths = [safe_path(config["data_path"])]
    return data_paths


def _continuous_loss(pred, x0, noise, a_bar_t, cont_target, clamp_x0, weights=None):
    """Mean squared error of the continuous head against the configured target.

    ``cont_target`` selects the target: ``"x0"`` is the clean data, clamped to
    ``[-clamp_x0, clamp_x0]`` when ``clamp_x0 > 0``. ``"v"`` is
    ``sqrt(a_bar) * noise - sqrt(1 - a_bar) * x0``. Anything else is the noise
    ``eps``. Optional ``weights`` (broadcast over the last axis) rescale each
    feature's squared error before averaging.
    """
    if cont_target == "x0":
        target = torch.clamp(x0, -clamp_x0, clamp_x0) if clamp_x0 > 0 else x0
    elif cont_target == "v":
        target = torch.sqrt(a_bar_t) * noise - torch.sqrt(1.0 - a_bar_t) * x0
    else:
        target = noise
    sq_err = (pred - target) ** 2
    if weights is not None:
        sq_err = sq_err * weights
    return sq_err.mean()


def _discrete_loss(logits, x_disc, mask):
    """Average over discrete columns of cross-entropy on masked positions only.

    Columns with no masked position are skipped. Returns ``0.0`` (a Python
    float) when nothing is masked. Presumably ``mask`` and ``x_disc`` are
    (batch, seq, n_disc) and ``logits`` holds one tensor per column (TODO confirm).
    """
    total = 0.0
    count = 0
    for i, logit in enumerate(logits):
        col_mask = mask[:, :, i]
        if col_mask.any():
            total = total + F.cross_entropy(logit[col_mask], x_disc[:, :, i][col_mask])
            count += 1
    if count > 0:
        total = total / count
    return total


def _quantile_loss(pred, x_cont, x_cont_t, a_bar_t, cont_target, q_tensor, q_clip, q_delta):
    """Distance between per-feature quantiles of real x0 and the model's x0 estimate.

    The x0 estimate is derived from the prediction according to
    ``cont_target``. Both tensors are optionally clamped to
    ``[-q_clip, q_clip]``. The loss is smooth-L1 (``beta=q_delta``) when
    ``q_delta > 0``, otherwise the mean absolute difference.
    """
    if cont_target == "x0":
        x_gen = pred
    elif cont_target == "v":
        x_gen = torch.sqrt(a_bar_t) * x_cont_t - torch.sqrt(1.0 - a_bar_t) * pred
    else:
        # Inverts x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps (presumably what
        # q_sample_continuous implements). NOTE(review): sqrt(a_bar) can be very
        # small at large t, so this estimate may blow up; q_clip bounds it.
        x_gen = (x_cont_t - torch.sqrt(1.0 - a_bar_t) * pred) / torch.sqrt(a_bar_t)
    x_real = x_cont
    if q_clip > 0:
        x_real = torch.clamp(x_real, -q_clip, q_clip)
        x_gen = torch.clamp(x_gen, -q_clip, q_clip)
    # Flatten all leading axes so quantiles are taken per feature (last axis).
    # reshape (not view) also tolerates a non-contiguous model output.
    x_real = x_real.reshape(-1, x_real.size(-1))
    x_gen = x_gen.reshape(-1, x_gen.size(-1))
    q_real = torch.quantile(x_real, q_tensor, dim=0)
    q_gen = torch.quantile(x_gen, q_tensor, dim=0)
    if q_delta > 0:
        return F.smooth_l1_loss(q_gen, q_real, beta=q_delta)
    return torch.mean(torch.abs(q_gen - q_real))


def main():
    """Train the hybrid diffusion model and write its artifacts.

    Configuration precedence: ``DEFAULTS`` is the base. An optional ``--config``
    JSON overrides it, and relative paths inside that file resolve against the
    file's directory. A non-"auto" ``--device`` overrides both.

    Outputs in ``out_dir``:
      * ``train_log.csv`` and ``config_used.json``;
      * ``model_ckpt.pt`` every ``ckpt_every`` steps;
      * ``model.pt`` (and ``model_ema.pt`` when EMA is on) at the end.

    A non-positive ``log_every`` / ``ckpt_every`` disables logging or
    checkpointing instead of raising ``ZeroDivisionError``.
    """
    args = parse_args()
    config = dict(DEFAULTS)
    if args.config:
        cfg_path = Path(args.config).resolve()
        print("using_config", str(cfg_path))
        config.update(load_json(str(cfg_path)))
        config = resolve_config_paths(config, cfg_path.parent)
    else:
        config = resolve_config_paths(config, BASE_DIR)

    # A device given on the command line takes precedence over the config.
    if args.device != "auto":
        config["device"] = args.device

    set_seed(int(config["seed"]))

    # Feature layout: drop the time column from both groups and "attack*"
    # columns from the discrete group.
    split = load_split(config["split_path"])
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]

    stats = load_json(config["stats_path"])
    mean = stats["mean"]
    std = stats["std"]
    transforms = stats.get("transform", {})
    raw_std = stats.get("raw_std", std)

    vocab = load_json(config["vocab_path"])["vocab"]
    vocab_sizes = [len(vocab[c]) for c in disc_cols]

    data_paths = _resolve_data_paths(config)

    use_condition = bool(config.get("use_condition")) and config.get("condition_type") == "file_id"
    cond_vocab_size = len(data_paths) if use_condition else 0

    device = resolve_device(str(config["device"]))
    print("device", device)
    model = HybridDiffusionModel(
        cont_dim=len(cont_cols),
        disc_vocab_sizes=vocab_sizes,
        time_dim=int(config.get("model_time_dim", 64)),
        hidden_dim=int(config.get("model_hidden_dim", 256)),
        num_layers=int(config.get("model_num_layers", 1)),
        dropout=float(config.get("model_dropout", 0.0)),
        ff_mult=int(config.get("model_ff_mult", 2)),
        pos_dim=int(config.get("model_pos_dim", 64)),
        use_pos_embed=bool(config.get("model_use_pos_embed", True)),
        cond_vocab_size=cond_vocab_size,
        cond_dim=int(config.get("cond_dim", 32)),
        use_tanh_eps=bool(config.get("use_tanh_eps", False)),
        eps_scale=float(config.get("eps_scale", 1.0)),
    ).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))
    ema = EMA(model, float(config["ema_decay"])) if config.get("use_ema") else None

    betas = cosine_beta_schedule(int(config["timesteps"])).to(device)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    os.makedirs(config["out_dir"], exist_ok=True)
    out_dir = safe_path(config["out_dir"])
    log_path = os.path.join(out_dir, "train_log.csv")
    q_weight_base = float(config.get("quantile_loss_weight", 0.0))
    use_quantile = q_weight_base > 0
    with open(log_path, "w", encoding="utf-8") as f:
        if use_quantile:
            f.write("epoch,step,loss,loss_cont,loss_disc,loss_quantile\n")
        else:
            f.write("epoch,step,loss,loss_cont,loss_disc\n")
    with open(os.path.join(out_dir, "config_used.json"), "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    # ---- Loop-invariant settings and tensors (previously rebuilt every step). ----
    timesteps = int(config["timesteps"])
    lam = float(config["lambda"])
    cont_target = str(config.get("cont_target", "eps"))
    clamp_x0 = float(config.get("cont_clamp_x0", 0.0))
    disc_mask_scale = float(config.get("disc_mask_scale", 1.0))
    grad_clip = float(config.get("grad_clip", 0.0))
    log_every = int(config["log_every"])
    ckpt_every = int(config["ckpt_every"])
    # Passed to q_sample_discrete; presumably the per-column [MASK] id, since
    # real token ids are < len(vocab[col]).
    mask_tokens = torch.tensor(vocab_sizes, device=device)

    # Master copies are built in float64 (exact for Python floats) and cast
    # once to the dtype seen at run time. This gives the same values as
    # building them directly in that dtype, as the per-step code used to.
    weights_master = None
    if config.get("cont_loss_weighting") == "inv_std":
        loss_eps = float(config.get("cont_loss_eps", 1e-6))
        weights_master = torch.tensor(
            [1.0 / (float(raw_std[c]) ** 2 + loss_eps) for c in cont_cols],
            device=device,
            dtype=torch.float64,
        ).view(1, 1, -1)
    cont_weights = weights_master

    q_master = None
    q_warmup = int(config.get("quantile_loss_warmup_steps", 0))
    q_clip = float(config.get("quantile_loss_clip", 0.0))
    q_delta = float(config.get("quantile_loss_huber_delta", 0.0))
    if use_quantile:
        q_points = config.get("quantile_points", [0.05, 0.25, 0.5, 0.75, 0.95])
        q_master = torch.tensor(q_points, device=device, dtype=torch.float64)
    q_tensor = q_master

    total_step = 0
    for epoch in range(int(config["epochs"])):
        batches = windowed_batches(
            data_paths,
            cont_cols,
            disc_cols,
            vocab,
            mean,
            std,
            batch_size=int(config["batch_size"]),
            seq_len=int(config["seq_len"]),
            max_batches=int(config["max_batches"]),
            return_file_id=use_condition,
            transforms=transforms,
            shuffle_buffer=int(config.get("shuffle_buffer", 0)),
        )
        for step, batch in enumerate(batches):
            if use_condition:
                x_cont, x_disc, cond = batch
                cond = cond.to(device)
            else:
                x_cont, x_disc = batch
                cond = None
            x_cont = x_cont.to(device)
            x_disc = x_disc.to(device)

            bsz = x_cont.size(0)
            t = torch.randint(0, timesteps, (bsz,), device=device)
            x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
            x_disc_t, mask = q_sample_discrete(
                x_disc,
                t,
                mask_tokens,
                timesteps,
                mask_scale=disc_mask_scale,
            )

            eps_pred, logits = model(x_cont_t, x_disc_t, t, cond)
            a_bar_t = alphas_cumprod[t].view(-1, 1, 1)

            if weights_master is not None and cont_weights.dtype != eps_pred.dtype:
                cont_weights = weights_master.to(eps_pred.dtype)
            loss_cont = _continuous_loss(
                eps_pred, x_cont, noise, a_bar_t, cont_target, clamp_x0, cont_weights
            )
            loss_disc = _discrete_loss(logits, x_disc, mask)
            loss = lam * loss_cont + (1 - lam) * loss_disc

            quantile_loss = 0.0
            if use_quantile:
                q_weight = q_weight_base
                if q_warmup > 0:
                    q_weight = q_weight * min(1.0, (total_step + 1) / float(q_warmup))
                if q_tensor.dtype != x_cont.dtype:
                    q_tensor = q_master.to(x_cont.dtype)
                quantile_loss = _quantile_loss(
                    eps_pred, x_cont, x_cont_t, a_bar_t, cont_target, q_tensor, q_clip, q_delta
                )
                loss = loss + q_weight * quantile_loss

            opt.zero_grad()
            loss.backward()
            if grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            opt.step()
            if ema is not None:
                ema.update(model)

            if log_every > 0 and step % log_every == 0:
                print("epoch", epoch, "step", step, "loss", float(loss))
                with open(log_path, "a", encoding="utf-8") as f:
                    if use_quantile:
                        f.write(
                            "%d,%d,%.6f,%.6f,%.6f,%.6f\n"
                            % (
                                epoch,
                                step,
                                float(loss),
                                float(loss_cont),
                                float(loss_disc),
                                float(quantile_loss),
                            )
                        )
                    else:
                        f.write(
                            "%d,%d,%.6f,%.6f,%.6f\n"
                            % (epoch, step, float(loss), float(loss_cont), float(loss_disc))
                        )

            total_step += 1
            if ckpt_every > 0 and total_step % ckpt_every == 0:
                ckpt = {
                    "model": model.state_dict(),
                    "optim": opt.state_dict(),
                    "config": config,
                    "step": total_step,
                }
                if ema is not None:
                    ckpt["ema"] = ema.state_dict()
                torch.save(ckpt, os.path.join(out_dir, "model_ckpt.pt"))

    torch.save(model.state_dict(), os.path.join(out_dir, "model.pt"))
    if ema is not None:
        torch.save(ema.state_dict(), os.path.join(out_dir, "model_ema.pt"))
|
||
|
||
|
||
if __name__ == "__main__":
    # Train only when executed as a script, never on import.
    main()
|