#!/usr/bin/env python3
"""Train hybrid diffusion on HAI 21.03 (minimal runnable example)."""

import json
import os

import torch
import torch.nn.functional as F

from data_utils import load_split, windowed_batches
from hybrid_diffusion import (
    HybridDiffusionModel,
    cosine_beta_schedule,
    q_sample_continuous,
    q_sample_discrete,
)

# Paths: HAI dataset, feature split, precomputed stats/vocab, and output directory.
DATA_PATH = "/home/anay/Dev/diffusion/dataset/hai/hai-21.03/train1.csv.gz"
SPLIT_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/feature_split.json"
STATS_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/results/cont_stats.json"
VOCAB_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/results/disc_vocab.json"
OUT_DIR = "/home/anay/Dev/diffusion/mask-ddpm/example/results"

# Training hyperparameters.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TIMESTEPS = 1000
BATCH_SIZE = 8
SEQ_LEN = 64
EPOCHS = 1
MAX_BATCHES = 50
LAMBDA = 0.5  # weight on the continuous loss; (1 - LAMBDA) weights the discrete loss
LR = 1e-3


def load_stats():
    """Load per-column mean/std statistics for the continuous features."""
    with open(STATS_PATH, "r", encoding="ascii") as f:
        return json.load(f)


def load_vocab():
    """Load the per-column vocabulary for the discrete features."""
    with open(VOCAB_PATH, "r", encoding="ascii") as f:
        return json.load(f)["vocab"]


def main():
    # Column split: which features are treated as continuous vs. discrete.
    split = load_split(SPLIT_PATH)
    cont_cols = split["continuous"]
    disc_cols = split["discrete"]

    # Normalization stats for continuous columns and vocabularies for discrete ones.
    stats = load_stats()
    mean = stats["mean"]
    std = stats["std"]
    vocab = load_vocab()

    vocab_sizes = [len(vocab[c]) for c in disc_cols]

    print("device", DEVICE)
    model = HybridDiffusionModel(
        cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes
    ).to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=LR)

    # Cosine noise schedule for the continuous (Gaussian) forward process.
    betas = cosine_beta_schedule(TIMESTEPS).to(DEVICE)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    os.makedirs(OUT_DIR, exist_ok=True)

    # Training loop over sliding-window batches from the HAI time series.
    for epoch in range(EPOCHS):
        for step, (x_cont, x_disc) in enumerate(
            windowed_batches(
                DATA_PATH,
                cont_cols,
                disc_cols,
                vocab,
                mean,
                std,
                batch_size=BATCH_SIZE,
                seq_len=SEQ_LEN,
                max_batches=MAX_BATCHES,
            )
        ):
            x_cont = x_cont.to(DEVICE)
            x_disc = x_disc.to(DEVICE)

            # Sample one diffusion timestep per sequence in the batch.
            bsz = x_cont.size(0)
            t = torch.randint(0, TIMESTEPS, (bsz,), device=DEVICE)

            # Forward process, continuous branch: add Gaussian noise at step t.
            x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)

            # Forward process, discrete branch: replace tokens with per-feature
            # mask ids (set to the vocab size, one past the last real token).
            mask_tokens = torch.tensor(vocab_sizes, device=DEVICE)
            x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, TIMESTEPS)

            eps_pred, logits = model(x_cont_t, x_disc_t, t)

            # Continuous loss: MSE between predicted and true noise.
            loss_cont = F.mse_loss(eps_pred, noise)

            # Discrete loss: cross-entropy over masked positions only.
            loss_disc = 0.0
            for i, logit in enumerate(logits):
                if mask[:, :, i].any():
                    loss_disc = loss_disc + F.cross_entropy(
                        logit[mask[:, :, i]], x_disc[:, :, i][mask[:, :, i]]
                    )

            loss = LAMBDA * loss_cont + (1 - LAMBDA) * loss_disc
            opt.zero_grad()
            loss.backward()
            opt.step()

            if step % 10 == 0:
                print("epoch", epoch, "step", step, "loss", float(loss))

    torch.save(model.state_dict(), os.path.join(OUT_DIR, "model.pt"))


if __name__ == "__main__":
    main()