mask-ddpm/example/train.py

#!/usr/bin/env python3
"""Train hybrid diffusion on HAI 21.03 (minimal runnable example)."""
import json
import os
import torch
import torch.nn.functional as F
from data_utils import load_split, windowed_batches
from hybrid_diffusion import (
    HybridDiffusionModel,
    cosine_beta_schedule,
    q_sample_continuous,
    q_sample_discrete,
)
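
# Input CSV and preprocessing artifacts (feature split, continuous stats,
# discrete vocabulary) consumed below.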
DATA_PATH = "/home/anay/Dev/diffusion/dataset/hai/hai-21.03/train1.csv.gz"
SPLIT_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/feature_split.json"
STATS_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/results/cont_stats.json"
VOCAB_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/results/disc_vocab.json"
OUT_DIR = "/home/anay/Dev/diffusion/mask-ddpm/example/results"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
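
# Diffusion and training hyperparameters for the minimal run.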
TIMESTEPS = 1000
BATCH_SIZE = 8
SEQ_LEN = 64
EPOCHS = 1
MAX_BATCHES = 50
LAMBDA = 0.5
LR = 1e-3
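

# JSON helpers: per-feature mean/std for continuous columns and the token
# vocabulary for discrete columns.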
def load_stats():
    with open(STATS_PATH, "r", encoding="ascii") as f:
        return json.load(f)


def load_vocab():
    with open(VOCAB_PATH, "r", encoding="ascii") as f:
        return json.load(f)["vocab"]


def main():
    split = load_split(SPLIT_PATH)
    cont_cols = split["continuous"]
    disc_cols = split["discrete"]
    stats = load_stats()
    mean = stats["mean"]
    std = stats["std"]
    vocab = load_vocab()
    vocab_sizes = [len(vocab[c]) for c in disc_cols]
    print("device", DEVICE)
    model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=LR)
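    # Precompute the cosine noise schedule; the cumulative alpha products let
    # q_sample_continuous jump straight from x_0 to x_t.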
    betas = cosine_beta_schedule(TIMESTEPS).to(DEVICE)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    # Per-feature mask token ids: each equals that feature's vocab size (one
    # past the last real token). Constant, so build the tensor once here
    # rather than inside the batch loop.
    mask_tokens = torch.tensor(vocab_sizes, device=DEVICE)
    os.makedirs(OUT_DIR, exist_ok=True)
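    # Training loop over windowed batches of the time series.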
    for epoch in range(EPOCHS):
        for step, (x_cont, x_disc) in enumerate(
            windowed_batches(
                DATA_PATH,
                cont_cols,
                disc_cols,
                vocab,
                mean,
                std,
                batch_size=BATCH_SIZE,
                seq_len=SEQ_LEN,
                max_batches=MAX_BATCHES,
            )
        ):
            x_cont = x_cont.to(DEVICE)
            x_disc = x_disc.to(DEVICE)
            bsz = x_cont.size(0)
            t = torch.randint(0, TIMESTEPS, (bsz,), device=DEVICE)
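            # Forward-diffuse both modalities at timestep t: Gaussian noise on
            # the continuous features, token masking on the discrete ones.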
            x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
            x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, TIMESTEPS)
            eps_pred, logits = model(x_cont_t, x_disc_t, t)
            loss_cont = F.mse_loss(eps_pred, noise)
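            # Discrete loss: cross-entropy only at masked positions, per feature.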
            loss_disc = 0.0
            for i, logit in enumerate(logits):
                if mask[:, :, i].any():
                    loss_disc = loss_disc + F.cross_entropy(
                        logit[mask[:, :, i]], x_disc[:, :, i][mask[:, :, i]]
                    )
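            # Blend the two objectives; LAMBDA trades continuous vs. discrete fit.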
            loss = LAMBDA * loss_cont + (1 - LAMBDA) * loss_disc
            opt.zero_grad()
            loss.backward()
            opt.step()
            if step % 10 == 0:
                print("epoch", epoch, "step", step, "loss", float(loss))
    torch.save(model.state_dict(), os.path.join(OUT_DIR, "model.pt"))


if __name__ == "__main__":
    main()