Clean artifacts and update example pipeline
example/train.py (163 changed lines)
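With this change, example/train.py no longer hard-codes /home/anay/... paths and module-level hyperparameters. It builds its defaults from paths relative to the repository (the new DEFAULTS dict), and an optional --config JSON file can override any subset of those keys; relative paths in the config are resolved against the config file's directory. A minimal sketch of driving the new interface (the file name my_config.json and the override values are illustrative, not part of the commit):

# Hypothetical driver: write a partial override config, then launch training.
# Only the keys being changed need to appear; the rest fall back to DEFAULTS.
import json
import subprocess

overrides = {
    "device": "cpu",           # "cpu", "cuda", or "auto"
    "epochs": 2,
    "max_batches": 20,
    "out_dir": "results_cpu",  # relative, so resolved next to my_config.json
}
with open("my_config.json", "w", encoding="ascii") as f:
    json.dump(overrides, f, indent=2)

subprocess.run(["python", "example/train.py", "--config", "my_config.json"], check=True)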
@@ -1,8 +1,12 @@
 #!/usr/bin/env python3
-"""Train hybrid diffusion on HAI 21.03 (minimal runnable example)."""
+"""Train hybrid diffusion on HAI (configurable runnable example)."""
 
+import argparse
 import json
 import os
+import random
+from pathlib import Path
+from typing import Dict
 
 import torch
 import torch.nn.functional as F
@@ -15,83 +19,140 @@ from hybrid_diffusion import (
     q_sample_discrete,
 )
 
-DATA_PATH = "/home/anay/Dev/diffusion/dataset/hai/hai-21.03/train1.csv.gz"
-SPLIT_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/feature_split.json"
-STATS_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/results/cont_stats.json"
-VOCAB_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/results/disc_vocab.json"
-OUT_DIR = "/home/anay/Dev/diffusion/mask-ddpm/example/results"
-
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-TIMESTEPS = 1000
-BATCH_SIZE = 8
-SEQ_LEN = 64
-EPOCHS = 1
-MAX_BATCHES = 50
-LAMBDA = 0.5
-LR = 1e-3
+BASE_DIR = Path(__file__).resolve().parent
+REPO_DIR = BASE_DIR.parent.parent
+
+DEFAULTS = {
+    "data_path": str(REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz"),
+    "split_path": str(BASE_DIR / "feature_split.json"),
+    "stats_path": str(BASE_DIR / "results" / "cont_stats.json"),
+    "vocab_path": str(BASE_DIR / "results" / "disc_vocab.json"),
+    "out_dir": str(BASE_DIR / "results"),
+    "device": "auto",
+    "timesteps": 1000,
+    "batch_size": 8,
+    "seq_len": 64,
+    "epochs": 1,
+    "max_batches": 50,
+    "lambda": 0.5,
+    "lr": 1e-3,
+    "seed": 1337,
+    "log_every": 10,
+    "ckpt_every": 50,
+}
 
 
-def load_stats():
-    with open(STATS_PATH, "r", encoding="ascii") as f:
+def load_json(path: str) -> Dict:
+    with open(path, "r", encoding="ascii") as f:
         return json.load(f)
 
 
-def load_vocab():
-    with open(VOCAB_PATH, "r", encoding="ascii") as f:
-        return json.load(f)["vocab"]
+def set_seed(seed: int):
+    random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+
+def resolve_device(mode: str) -> str:
+    mode = mode.lower()
+    if mode == "cpu":
+        return "cpu"
+    if mode == "cuda":
+        if not torch.cuda.is_available():
+            raise SystemExit("device set to cuda but CUDA is not available")
+        return "cuda"
+    if torch.cuda.is_available():
+        return "cuda"
+    return "cpu"
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Train hybrid diffusion on HAI.")
+    parser.add_argument("--config", default=None, help="Path to JSON config.")
+    return parser.parse_args()
+
+
+def resolve_config_paths(config, base_dir: Path):
+    keys = ["data_path", "split_path", "stats_path", "vocab_path", "out_dir"]
+    for key in keys:
+        if key in config:
+            path = Path(str(config[key]))
+            if not path.is_absolute():
+                config[key] = str((base_dir / path).resolve())
+    return config
 
 
 def main():
-    split = load_split(SPLIT_PATH)
-    cont_cols = split["continuous"]
-    disc_cols = split["discrete"]
+    args = parse_args()
+    config = dict(DEFAULTS)
+    if args.config:
+        cfg_path = Path(args.config).resolve()
+        config.update(load_json(str(cfg_path)))
+        config = resolve_config_paths(config, cfg_path.parent)
+    else:
+        config = resolve_config_paths(config, BASE_DIR)
 
-    stats = load_stats()
+    set_seed(int(config["seed"]))
+
+    split = load_split(config["split_path"])
+    time_col = split.get("time_column", "time")
+    cont_cols = [c for c in split["continuous"] if c != time_col]
+    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
+
+    stats = load_json(config["stats_path"])
     mean = stats["mean"]
     std = stats["std"]
-    vocab = load_vocab()
+
+    vocab = load_json(config["vocab_path"])["vocab"]
     vocab_sizes = [len(vocab[c]) for c in disc_cols]
 
-    print("device", DEVICE)
-    model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(DEVICE)
-    opt = torch.optim.Adam(model.parameters(), lr=LR)
+    device = resolve_device(str(config["device"]))
+    print("device", device)
+    model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(device)
+    opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))
 
-    betas = cosine_beta_schedule(TIMESTEPS).to(DEVICE)
+    betas = cosine_beta_schedule(int(config["timesteps"])).to(device)
     alphas = 1.0 - betas
     alphas_cumprod = torch.cumprod(alphas, dim=0)
 
-    os.makedirs(OUT_DIR, exist_ok=True)
+    os.makedirs(config["out_dir"], exist_ok=True)
+    log_path = os.path.join(config["out_dir"], "train_log.csv")
+    with open(log_path, "w", encoding="ascii") as f:
+        f.write("epoch,step,loss,loss_cont,loss_disc\n")
 
-    for epoch in range(EPOCHS):
+    total_step = 0
+    for epoch in range(int(config["epochs"])):
         for step, (x_cont, x_disc) in enumerate(
             windowed_batches(
-                DATA_PATH,
+                config["data_path"],
                 cont_cols,
                 disc_cols,
                 vocab,
                 mean,
                 std,
-                batch_size=BATCH_SIZE,
-                seq_len=SEQ_LEN,
-                max_batches=MAX_BATCHES,
+                batch_size=int(config["batch_size"]),
+                seq_len=int(config["seq_len"]),
+                max_batches=int(config["max_batches"]),
             )
         ):
-            x_cont = x_cont.to(DEVICE)
-            x_disc = x_disc.to(DEVICE)
+            x_cont = x_cont.to(device)
+            x_disc = x_disc.to(device)
 
             bsz = x_cont.size(0)
-            t = torch.randint(0, TIMESTEPS, (bsz,), device=DEVICE)
+            t = torch.randint(0, int(config["timesteps"]), (bsz,), device=device)
 
             x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
 
-            mask_tokens = torch.tensor(vocab_sizes, device=DEVICE)
-            x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, TIMESTEPS)
+            mask_tokens = torch.tensor(vocab_sizes, device=device)
+            x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, int(config["timesteps"]))
 
             eps_pred, logits = model(x_cont_t, x_disc_t, t)
 
             loss_cont = F.mse_loss(eps_pred, noise)
 
             loss_disc = 0.0
             for i, logit in enumerate(logits):
                 if mask[:, :, i].any():
@@ -99,15 +160,31 @@ def main():
                         logit[mask[:, :, i]], x_disc[:, :, i][mask[:, :, i]]
                     )
 
-            loss = LAMBDA * loss_cont + (1 - LAMBDA) * loss_disc
+            lam = float(config["lambda"])
+            loss = lam * loss_cont + (1 - lam) * loss_disc
             opt.zero_grad()
             loss.backward()
             opt.step()
 
-            if step % 10 == 0:
+            if step % int(config["log_every"]) == 0:
                 print("epoch", epoch, "step", step, "loss", float(loss))
+                with open(log_path, "a", encoding="ascii") as f:
+                    f.write(
+                        "%d,%d,%.6f,%.6f,%.6f\n"
+                        % (epoch, step, float(loss), float(loss_cont), float(loss_disc))
+                    )
 
-    torch.save(model.state_dict(), os.path.join(OUT_DIR, "model.pt"))
+            total_step += 1
+            if total_step % int(config["ckpt_every"]) == 0:
+                ckpt = {
+                    "model": model.state_dict(),
+                    "optim": opt.state_dict(),
+                    "config": config,
+                    "step": total_step,
+                }
+                torch.save(ckpt, os.path.join(config["out_dir"], "model_ckpt.pt"))
+
+    torch.save(model.state_dict(), os.path.join(config["out_dir"], "model.pt"))
 
 
 if __name__ == "__main__":
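Note on the new checkpointing: every ckpt_every optimizer steps the run now writes model_ckpt.pt containing the model weights, optimizer state, resolved config, and global step, alongside the final model.pt. A rough sketch of loading it back for inspection or resumption; it assumes load_split is exported by hybrid_diffusion (as the import block at the top of train.py suggests) and uses the default out_dir:

# Hypothetical: reload the bundled checkpoint written by train.py.
import json
import torch
from hybrid_diffusion import HybridDiffusionModel, load_split

ckpt = torch.load("example/results/model_ckpt.pt", map_location="cpu")
config = ckpt["config"]                      # resolved paths and hyperparameters
print("checkpointed at step", ckpt["step"])

# Rebuild the model shape the same way train.py derives it from the split/vocab.
split = load_split(config["split_path"])
time_col = split.get("time_column", "time")
cont_cols = [c for c in split["continuous"] if c != time_col]
disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
with open(config["vocab_path"], "r", encoding="ascii") as f:
    vocab = json.load(f)["vocab"]

model = HybridDiffusionModel(
    cont_dim=len(cont_cols),
    disc_vocab_sizes=[len(vocab[c]) for c in disc_cols],
)
model.load_state_dict(ckpt["model"])

opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))
opt.load_state_dict(ckpt["optim"])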
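The per-step CSV log gives a quick way to sanity-check a run without touching checkpoints; its header epoch,step,loss,loss_cont,loss_disc is written at startup. A small stdlib-only sketch (the path assumes the default out_dir):

# Hypothetical: print the most recent losses logged by train.py.
import csv

with open("example/results/train_log.csv", newline="", encoding="ascii") as f:
    rows = list(csv.DictReader(f))

last = rows[-1]
print(
    "epoch", last["epoch"], "step", last["step"],
    "loss", last["loss"], "cont", last["loss_cont"], "disc", last["loss_disc"],
)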