Files
mask-ddpm/example/train.py

192 lines
6.0 KiB
Python
Executable File

#!/usr/bin/env python3
"""Train hybrid diffusion on HAI (configurable runnable example)."""
import argparse
import json
import os
import random
from pathlib import Path
from typing import Dict
import torch
import torch.nn.functional as F
from data_utils import load_split, windowed_batches
from hybrid_diffusion import (
HybridDiffusionModel,
cosine_beta_schedule,
q_sample_continuous,
q_sample_discrete,
)
# Directory that contains this script; relative default paths hang off it.
BASE_DIR = Path(__file__).resolve().parent
# Two directory levels above the script directory.
REPO_DIR = BASE_DIR.parent.parent
# Location of the statistics/vocabulary inputs and of all training outputs.
_RESULTS_DIR = BASE_DIR.joinpath("results")

# Built-in settings. main() copies this dict and lets the JSON file given via
# --config override individual keys.
DEFAULTS = {
    # --- file locations ---
    "data_path": str(
        REPO_DIR.joinpath("dataset", "hai", "hai-21.03", "train1.csv.gz")
    ),
    "split_path": str(BASE_DIR.joinpath("feature_split.json")),
    "stats_path": str(_RESULTS_DIR.joinpath("cont_stats.json")),
    "vocab_path": str(_RESULTS_DIR.joinpath("disc_vocab.json")),
    "out_dir": str(_RESULTS_DIR),
    # --- runtime ---
    "device": "auto",
    # --- diffusion / batching ---
    "timesteps": 1000,
    "batch_size": 8,
    "seq_len": 64,
    "epochs": 1,
    "max_batches": 50,
    # --- optimisation ---
    "lambda": 0.5,  # weight of the continuous loss; (1 - lambda) weights the discrete loss
    "lr": 1e-3,
    "seed": 1337,
    # --- reporting ---
    "log_every": 10,
    "ckpt_every": 50,
}
def load_json(path: str) -> Dict:
    """Parse the JSON file at ``path`` and return the decoded value.

    The file is decoded as UTF-8, the interchange encoding for JSON. ASCII is
    a subset of UTF-8, so pure-ASCII files load exactly as before, while files
    containing non-ASCII characters no longer fail with ``UnicodeDecodeError``.

    Args:
        path: Location of the JSON file.

    Returns:
        Whatever ``json.load`` produces for the document (callers in this
        file index the result like a dict).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
def set_seed(seed: int):
    """Seed the RNGs used by this script and request deterministic cuDNN.

    Seeds Python's ``random`` module and PyTorch's default generator; when
    CUDA is available, all CUDA generators are seeded as well. The cuDNN
    ``deterministic`` flag is switched on and ``benchmark`` is switched off.

    Args:
        seed: Value handed to every generator.
    """
    random.seed(seed)
    torch.manual_seed(seed)

    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def resolve_device(mode: str) -> str:
    """Translate a device setting into ``"cpu"`` or ``"cuda"``.

    The comparison is case-insensitive. ``"cpu"`` is returned as is;
    ``"cuda"`` is returned only if CUDA is available; any other value
    selects CUDA when it is available and the CPU otherwise.

    Args:
        mode: Requested device setting.

    Returns:
        ``"cpu"`` or ``"cuda"``.

    Raises:
        SystemExit: If ``"cuda"`` is requested but CUDA is not available.
    """
    requested = mode.lower()
    if requested == "cpu":
        return "cpu"

    has_cuda = torch.cuda.is_available()
    if requested == "cuda" and not has_cuda:
        raise SystemExit("device set to cuda but CUDA is not available")

    # Reached for a satisfiable "cuda" request and for every other value:
    # both cases reduce to "use CUDA if it exists".
    return "cuda" if has_cuda else "cpu"
def parse_args():
    """Parse the command line of the training script.

    Returns:
        The ``argparse`` namespace; its ``config`` attribute holds the value
        of ``--config`` or ``None`` when the option is absent.
    """
    cli = argparse.ArgumentParser(description="Train hybrid diffusion on HAI.")
    cli.add_argument("--config", help="Path to JSON config.", default=None)
    namespace = cli.parse_args()
    return namespace
def resolve_config_paths(config, base_dir: Path):
    """Make the path-valued settings in ``config`` absolute.

    Only the keys ``data_path``, ``split_path``, ``stats_path``,
    ``vocab_path`` and ``out_dir`` are considered. A key that is missing or
    already holds an absolute path is left alone; a relative path is joined
    onto ``base_dir`` and resolved.

    Args:
        config: Settings mapping; it is modified in place.
        base_dir: Directory that relative paths are interpreted against.

    Returns:
        The same ``config`` object that was passed in.
    """
    for name in ("data_path", "split_path", "stats_path", "vocab_path", "out_dir"):
        if name not in config:
            continue
        candidate = Path(str(config[name]))
        if candidate.is_absolute():
            continue
        config[name] = str((base_dir / candidate).resolve())
    return config
def _masked_discrete_loss(logits, x_disc, mask):
    """Sum the cross-entropy over the masked positions of each discrete feature.

    ``logits`` is iterated feature by feature; for feature ``i`` the slices
    ``mask[:, :, i]`` and ``x_disc[:, :, i]`` select the positions and the
    target classes. A feature without any masked position contributes nothing.

    Returns:
        The accumulated loss tensor, or the float ``0.0`` when no feature has
        a masked position.
    """
    total = 0.0
    for i, logit in enumerate(logits):
        feature_mask = mask[:, :, i]
        if feature_mask.any():
            total = total + F.cross_entropy(
                logit[feature_mask], x_disc[:, :, i][feature_mask]
            )
    return total


def main():
    """Train the hybrid diffusion model and write logs, checkpoints and weights.

    Steps:
      1. Start from ``DEFAULTS``; if ``--config`` is given, overlay the JSON
         file and resolve relative paths against that file's directory
         (otherwise against ``BASE_DIR``).
      2. Seed the RNGs, load the feature split, the continuous statistics and
         the discrete vocabulary.
      3. Build the model, an Adam optimiser and the cumulative alpha products
         of a cosine beta schedule.
      4. For every batch: noise the continuous part, mask the discrete part,
         predict, and minimise
         ``lambda * mse + (1 - lambda) * masked cross-entropy``.
      5. Every ``log_every`` steps (per epoch) print the loss and append a row
         to ``train_log.csv``; every ``ckpt_every`` global steps overwrite
         ``model_ckpt.pt``; after all epochs save ``model.pt``.

    Numeric settings are converted once, before any output file is created,
    so a malformed value is reported before training starts.
    """
    args = parse_args()
    config = dict(DEFAULTS)
    if args.config:
        cfg_path = Path(args.config).resolve()
        config.update(load_json(str(cfg_path)))
        # Relative paths in a user config are relative to the config file.
        config = resolve_config_paths(config, cfg_path.parent)
    else:
        config = resolve_config_paths(config, BASE_DIR)

    set_seed(int(config["seed"]))

    split = load_split(config["split_path"])
    time_col = split.get("time_column", "time")
    # The time column is never a model input; discrete columns whose name
    # starts with "attack" are excluded as well.
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_cols = [
        c for c in split["discrete"] if not c.startswith("attack") and c != time_col
    ]

    stats = load_json(config["stats_path"])
    mean = stats["mean"]
    std = stats["std"]
    vocab = load_json(config["vocab_path"])["vocab"]
    vocab_sizes = [len(vocab[c]) for c in disc_cols]

    device = resolve_device(str(config["device"]))
    print("device", device)

    # Convert every numeric setting exactly once instead of on each batch.
    timesteps = int(config["timesteps"])
    batch_size = int(config["batch_size"])
    seq_len = int(config["seq_len"])
    max_batches = int(config["max_batches"])
    epochs = int(config["epochs"])
    log_every = int(config["log_every"])
    ckpt_every = int(config["ckpt_every"])
    lam = float(config["lambda"])
    lr = float(config["lr"])

    model = HybridDiffusionModel(
        cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes
    ).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    betas = cosine_beta_schedule(timesteps).to(device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    # The per-feature vocabulary sizes are what q_sample_discrete receives as
    # mask token ids. They never change, so the tensor is built once.
    # NOTE(review): assumes q_sample_discrete does not modify mask_tokens in
    # place - confirm against hybrid_diffusion.
    mask_tokens = torch.tensor(vocab_sizes, device=device)

    out_dir = config["out_dir"]
    os.makedirs(out_dir, exist_ok=True)
    log_path = os.path.join(out_dir, "train_log.csv")
    ckpt_path = os.path.join(out_dir, "model_ckpt.pt")
    # Opening with "w" truncates any log left over from a previous run.
    with open(log_path, "w", encoding="ascii") as f:
        f.write("epoch,step,loss,loss_cont,loss_disc\n")

    total_step = 0  # counts optimiser steps across all epochs
    for epoch in range(epochs):
        batches = windowed_batches(
            config["data_path"],
            cont_cols,
            disc_cols,
            vocab,
            mean,
            std,
            batch_size=batch_size,
            seq_len=seq_len,
            max_batches=max_batches,
        )
        for step, (x_cont, x_disc) in enumerate(batches):
            x_cont = x_cont.to(device)
            x_disc = x_disc.to(device)

            # One random diffusion step per sequence in the batch.
            bsz = x_cont.size(0)
            t = torch.randint(0, timesteps, (bsz,), device=device)

            x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
            x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, timesteps)

            eps_pred, logits = model(x_cont_t, x_disc_t, t)

            loss_cont = F.mse_loss(eps_pred, noise)
            loss_disc = _masked_discrete_loss(logits, x_disc, mask)
            loss = lam * loss_cont + (1 - lam) * loss_disc

            opt.zero_grad()
            loss.backward()
            opt.step()

            # `step` restarts at 0 in every epoch, so the first batch of each
            # epoch is always logged.
            if step % log_every == 0:
                print("epoch", epoch, "step", step, "loss", float(loss))
                with open(log_path, "a", encoding="ascii") as f:
                    f.write(
                        "%d,%d,%.6f,%.6f,%.6f\n"
                        % (epoch, step, float(loss), float(loss_cont), float(loss_disc))
                    )

            total_step += 1
            if total_step % ckpt_every == 0:
                # Each checkpoint overwrites the previous one.
                ckpt = {
                    "model": model.state_dict(),
                    "optim": opt.state_dict(),
                    "config": config,
                    "step": total_step,
                }
                torch.save(ckpt, ckpt_path)

    torch.save(model.state_dict(), os.path.join(out_dir, "model.pt"))


if __name__ == "__main__":
    main()