Clean artifacts and update example pipeline

This commit is contained in:
2026-01-22 16:32:51 +08:00
parent c0639386be
commit c3f750cd9d
20 changed files with 651 additions and 30826 deletions

View File

@@ -1,8 +1,12 @@
#!/usr/bin/env python3
"""Train hybrid diffusion on HAI 21.03 (minimal runnable example)."""
"""Train hybrid diffusion on HAI (configurable runnable example)."""
import argparse
import json
import os
import random
from pathlib import Path
from typing import Dict
import torch
import torch.nn.functional as F
@@ -15,83 +19,140 @@ from hybrid_diffusion import (
q_sample_discrete,
)
DATA_PATH = "/home/anay/Dev/diffusion/dataset/hai/hai-21.03/train1.csv.gz"
SPLIT_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/feature_split.json"
STATS_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/results/cont_stats.json"
VOCAB_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/results/disc_vocab.json"
OUT_DIR = "/home/anay/Dev/diffusion/mask-ddpm/example/results"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TIMESTEPS = 1000
BATCH_SIZE = 8
SEQ_LEN = 64
EPOCHS = 1
MAX_BATCHES = 50
LAMBDA = 0.5
LR = 1e-3
BASE_DIR = Path(__file__).resolve().parent
REPO_DIR = BASE_DIR.parent.parent
DEFAULTS = {
"data_path": str(REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz"),
"split_path": str(BASE_DIR / "feature_split.json"),
"stats_path": str(BASE_DIR / "results" / "cont_stats.json"),
"vocab_path": str(BASE_DIR / "results" / "disc_vocab.json"),
"out_dir": str(BASE_DIR / "results"),
"device": "auto",
"timesteps": 1000,
"batch_size": 8,
"seq_len": 64,
"epochs": 1,
"max_batches": 50,
"lambda": 0.5,
"lr": 1e-3,
"seed": 1337,
"log_every": 10,
"ckpt_every": 50,
}
def load_stats():
with open(STATS_PATH, "r", encoding="ascii") as f:
def load_json(path: str) -> Dict:
    """Load and return the JSON document at *path*.

    Uses UTF-8 rather than ASCII: JSON is UTF-8 by specification
    (RFC 8259), and an ASCII codec would raise UnicodeDecodeError on any
    non-ASCII byte in the stats/vocab files. ASCII files still load
    unchanged since ASCII is a strict subset of UTF-8.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
def load_vocab():
with open(VOCAB_PATH, "r", encoding="ascii") as f:
return json.load(f)["vocab"]
def set_seed(seed: int):
    """Make training reproducible: seed every RNG and pin cuDNN to
    deterministic kernels (disabling its benchmark autotuner)."""
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def resolve_device(mode: str) -> str:
    """Turn a device setting into a concrete device string.

    "cpu" always yields "cpu"; "cuda" requires CUDA and aborts if it is
    missing; anything else (e.g. "auto") picks "cuda" when available,
    otherwise "cpu".
    """
    choice = mode.lower()
    if choice == "cpu":
        return "cpu"
    has_cuda = torch.cuda.is_available()
    if choice == "cuda":
        if not has_cuda:
            raise SystemExit("device set to cuda but CUDA is not available")
        return "cuda"
    return "cuda" if has_cuda else "cpu"
def parse_args():
    """Parse CLI arguments; the only flag is an optional --config JSON path."""
    cli = argparse.ArgumentParser(description="Train hybrid diffusion on HAI.")
    cli.add_argument("--config", default=None, help="Path to JSON config.")
    return cli.parse_args()
def resolve_config_paths(config, base_dir: Path):
    """Rewrite the path-valued config entries that are relative so they
    become absolute paths anchored at *base_dir*; absolute entries and
    missing keys are left untouched. Mutates and returns *config*."""
    for key in ("data_path", "split_path", "stats_path", "vocab_path", "out_dir"):
        if key not in config:
            continue
        candidate = Path(str(config[key]))
        if not candidate.is_absolute():
            config[key] = str((base_dir / candidate).resolve())
    return config
def main():
split = load_split(SPLIT_PATH)
cont_cols = split["continuous"]
disc_cols = split["discrete"]
args = parse_args()
config = dict(DEFAULTS)
if args.config:
cfg_path = Path(args.config).resolve()
config.update(load_json(str(cfg_path)))
config = resolve_config_paths(config, cfg_path.parent)
else:
config = resolve_config_paths(config, BASE_DIR)
stats = load_stats()
set_seed(int(config["seed"]))
split = load_split(config["split_path"])
time_col = split.get("time_column", "time")
cont_cols = [c for c in split["continuous"] if c != time_col]
disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
stats = load_json(config["stats_path"])
mean = stats["mean"]
std = stats["std"]
vocab = load_vocab()
vocab = load_json(config["vocab_path"])["vocab"]
vocab_sizes = [len(vocab[c]) for c in disc_cols]
print("device", DEVICE)
model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(DEVICE)
opt = torch.optim.Adam(model.parameters(), lr=LR)
device = resolve_device(str(config["device"]))
print("device", device)
model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(device)
opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))
betas = cosine_beta_schedule(TIMESTEPS).to(DEVICE)
betas = cosine_beta_schedule(int(config["timesteps"])).to(device)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
os.makedirs(OUT_DIR, exist_ok=True)
os.makedirs(config["out_dir"], exist_ok=True)
log_path = os.path.join(config["out_dir"], "train_log.csv")
with open(log_path, "w", encoding="ascii") as f:
f.write("epoch,step,loss,loss_cont,loss_disc\n")
for epoch in range(EPOCHS):
total_step = 0
for epoch in range(int(config["epochs"])):
for step, (x_cont, x_disc) in enumerate(
windowed_batches(
DATA_PATH,
config["data_path"],
cont_cols,
disc_cols,
vocab,
mean,
std,
batch_size=BATCH_SIZE,
seq_len=SEQ_LEN,
max_batches=MAX_BATCHES,
batch_size=int(config["batch_size"]),
seq_len=int(config["seq_len"]),
max_batches=int(config["max_batches"]),
)
):
x_cont = x_cont.to(DEVICE)
x_disc = x_disc.to(DEVICE)
x_cont = x_cont.to(device)
x_disc = x_disc.to(device)
bsz = x_cont.size(0)
t = torch.randint(0, TIMESTEPS, (bsz,), device=DEVICE)
t = torch.randint(0, int(config["timesteps"]), (bsz,), device=device)
x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
mask_tokens = torch.tensor(vocab_sizes, device=DEVICE)
x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, TIMESTEPS)
mask_tokens = torch.tensor(vocab_sizes, device=device)
x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, int(config["timesteps"]))
eps_pred, logits = model(x_cont_t, x_disc_t, t)
loss_cont = F.mse_loss(eps_pred, noise)
loss_disc = 0.0
for i, logit in enumerate(logits):
if mask[:, :, i].any():
@@ -99,15 +160,31 @@ def main():
logit[mask[:, :, i]], x_disc[:, :, i][mask[:, :, i]]
)
loss = LAMBDA * loss_cont + (1 - LAMBDA) * loss_disc
lam = float(config["lambda"])
loss = lam * loss_cont + (1 - lam) * loss_disc
opt.zero_grad()
loss.backward()
opt.step()
if step % 10 == 0:
if step % int(config["log_every"]) == 0:
print("epoch", epoch, "step", step, "loss", float(loss))
with open(log_path, "a", encoding="ascii") as f:
f.write(
"%d,%d,%.6f,%.6f,%.6f\n"
% (epoch, step, float(loss), float(loss_cont), float(loss_disc))
)
torch.save(model.state_dict(), os.path.join(OUT_DIR, "model.pt"))
total_step += 1
if total_step % int(config["ckpt_every"]) == 0:
ckpt = {
"model": model.state_dict(),
"optim": opt.state_dict(),
"config": config,
"step": total_step,
}
torch.save(ckpt, os.path.join(config["out_dir"], "model_ckpt.pt"))
torch.save(model.state_dict(), os.path.join(config["out_dir"], "model.pt"))
if __name__ == "__main__":