#!/usr/bin/env python3
"""Train a hybrid continuous/discrete diffusion model on HAI 21.03 (minimal runnable example)."""

import json
import os

import torch
import torch.nn.functional as F

from data_utils import load_split, windowed_batches
from hybrid_diffusion import (
    HybridDiffusionModel,
    cosine_beta_schedule,
    q_sample_continuous,
    q_sample_discrete,
)

DATA_PATH = "/home/anay/Dev/diffusion/dataset/hai/hai-21.03/train1.csv.gz"
SPLIT_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/feature_split.json"
STATS_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/results/cont_stats.json"
VOCAB_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/results/disc_vocab.json"
OUT_DIR = "/home/anay/Dev/diffusion/mask-ddpm/example/results"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TIMESTEPS = 1000   # number of diffusion steps
BATCH_SIZE = 8
SEQ_LEN = 64       # window length (rows per training sequence)
EPOCHS = 1
MAX_BATCHES = 50   # cap batches per epoch for a quick smoke run
LAMBDA = 0.5       # weight on the continuous loss; (1 - LAMBDA) weights the discrete loss
LR = 1e-3


def load_stats():
    """Load per-column mean/std for the continuous features."""
    with open(STATS_PATH, "r", encoding="utf-8") as f:
        return json.load(f)


def load_vocab():
    """Load the per-column token vocabularies for the discrete features."""
    with open(VOCAB_PATH, "r", encoding="utf-8") as f:
        return json.load(f)["vocab"]


def main():
    split = load_split(SPLIT_PATH)
    cont_cols = split["continuous"]
    disc_cols = split["discrete"]

    stats = load_stats()
    mean = stats["mean"]
    std = stats["std"]

    vocab = load_vocab()
    vocab_sizes = [len(vocab[c]) for c in disc_cols]

    print("device", DEVICE)
    model = HybridDiffusionModel(
        cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes
    ).to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=LR)

    # Noise schedule for the continuous (Gaussian) branch.
    betas = cosine_beta_schedule(TIMESTEPS).to(DEVICE)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    # Each column's mask token id is its vocab size, i.e. index len(vocab[c])
    # is reserved as [MASK]. Constant across steps, so build it once.
    mask_tokens = torch.tensor(vocab_sizes, device=DEVICE)

    os.makedirs(OUT_DIR, exist_ok=True)

    for epoch in range(EPOCHS):
        for step, (x_cont, x_disc) in enumerate(
            windowed_batches(
                DATA_PATH,
                cont_cols,
                disc_cols,
                vocab,
                mean,
                std,
                batch_size=BATCH_SIZE,
                seq_len=SEQ_LEN,
                max_batches=MAX_BATCHES,
            )
        ):
            x_cont = x_cont.to(DEVICE)
            x_disc = x_disc.to(DEVICE)
            bsz = x_cont.size(0)

            # One random diffusion timestep per sequence.
            t = torch.randint(0, TIMESTEPS, (bsz,), device=DEVICE)

            # Forward process: Gaussian noising for continuous features,
            # absorbing-state masking for discrete features.
            x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
            x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, TIMESTEPS)

            eps_pred, logits = model(x_cont_t, x_disc_t, t)

            # Continuous branch: standard epsilon-prediction MSE.
            loss_cont = F.mse_loss(eps_pred, noise)

            # Discrete branch: cross-entropy over masked positions only,
            # summed across the discrete columns.
            loss_disc = 0.0
            for i, logit in enumerate(logits):
                if mask[:, :, i].any():
                    loss_disc = loss_disc + F.cross_entropy(
                        logit[mask[:, :, i]], x_disc[:, :, i][mask[:, :, i]]
                    )

            loss = LAMBDA * loss_cont + (1 - LAMBDA) * loss_disc

            opt.zero_grad()
            loss.backward()
            opt.step()

            if step % 10 == 0:
                print("epoch", epoch, "step", step, "loss", float(loss))

    torch.save(model.state_dict(), os.path.join(OUT_DIR, "model.pt"))


if __name__ == "__main__":
    main()
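
# Reloading the saved checkpoint for evaluation or sampling -- a minimal
# sketch, assuming the model is rebuilt with the same constructor arguments
# used at training time (the sampling API itself lives in hybrid_diffusion
# and is not shown here):
#
#     model = HybridDiffusionModel(
#         cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes
#     )
#     model.load_state_dict(
#         torch.load(os.path.join(OUT_DIR, "model.pt"), map_location=DEVICE)
#     )
#     model.eval()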