Update example and notes
114
example/train.py
Executable file
@@ -0,0 +1,114 @@
#!/usr/bin/env python3
"""Train hybrid diffusion on HAI 21.03 (minimal runnable example)."""

import json
import os

import torch
import torch.nn.functional as F

from data_utils import load_split, windowed_batches
from hybrid_diffusion import (
    HybridDiffusionModel,
    cosine_beta_schedule,
    q_sample_continuous,
    q_sample_discrete,
)

# Absolute paths to the local HAI 21.03 dataset and the example's artifacts.
DATA_PATH = "/home/anay/Dev/diffusion/dataset/hai/hai-21.03/train1.csv.gz"
SPLIT_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/feature_split.json"
STATS_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/results/cont_stats.json"
VOCAB_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/results/disc_vocab.json"
OUT_DIR = "/home/anay/Dev/diffusion/mask-ddpm/example/results"

# Training hyperparameters for the minimal example.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TIMESTEPS = 1000
BATCH_SIZE = 8
SEQ_LEN = 64
EPOCHS = 1
MAX_BATCHES = 50
LAMBDA = 0.5
LR = 1e-3


def load_stats():
    with open(STATS_PATH, "r", encoding="ascii") as f:
        return json.load(f)


def load_vocab():
    with open(VOCAB_PATH, "r", encoding="ascii") as f:
        return json.load(f)["vocab"]


def main():
    split = load_split(SPLIT_PATH)
    cont_cols = split["continuous"]
    disc_cols = split["discrete"]

    stats = load_stats()
    mean = stats["mean"]
    std = stats["std"]
    vocab = load_vocab()

    vocab_sizes = [len(vocab[c]) for c in disc_cols]

    print("device", DEVICE)
    model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=LR)

    # Precompute the noise schedule and cumulative alphas for the forward process.
    betas = cosine_beta_schedule(TIMESTEPS).to(DEVICE)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    os.makedirs(OUT_DIR, exist_ok=True)

    for epoch in range(EPOCHS):
        for step, (x_cont, x_disc) in enumerate(
            windowed_batches(
                DATA_PATH,
                cont_cols,
                disc_cols,
                vocab,
                mean,
                std,
                batch_size=BATCH_SIZE,
                seq_len=SEQ_LEN,
                max_batches=MAX_BATCHES,
            )
        ):
            x_cont = x_cont.to(DEVICE)
            x_disc = x_disc.to(DEVICE)

            # Sample one random diffusion timestep per sequence in the batch.
            bsz = x_cont.size(0)
            t = torch.randint(0, TIMESTEPS, (bsz,), device=DEVICE)

            # Forward-diffuse the continuous features with Gaussian noise at step t.
            x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)

            # Each discrete column's mask token index equals its vocabulary size;
            # the discrete forward process replaces tokens with these mask tokens.
            mask_tokens = torch.tensor(vocab_sizes, device=DEVICE)
            x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, TIMESTEPS)

            eps_pred, logits = model(x_cont_t, x_disc_t, t)

            loss_cont = F.mse_loss(eps_pred, noise)

            # Cross-entropy only over positions that were actually masked at this step.
            loss_disc = 0.0
            for i, logit in enumerate(logits):
                if mask[:, :, i].any():
                    loss_disc = loss_disc + F.cross_entropy(
                        logit[mask[:, :, i]], x_disc[:, :, i][mask[:, :, i]]
                    )

            # Hybrid objective: weighted sum of continuous eps-MSE and discrete masked CE.
            loss = LAMBDA * loss_cont + (1 - LAMBDA) * loss_disc
            opt.zero_grad()
            loss.backward()
            opt.step()

            if step % 10 == 0:
                print("epoch", epoch, "step", step, "loss", float(loss))

    torch.save(model.state_dict(), os.path.join(OUT_DIR, "model.pt"))


if __name__ == "__main__":
    main()