Update example and notes
This commit is contained in:
113
example/train_stub.py
Executable file
113
example/train_stub.py
Executable file
@@ -0,0 +1,113 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Training stub for hybrid diffusion on HAI 21.03.
|
||||
|
||||
This is a scaffold that shows data loading, forward noising, and loss setup.
|
||||
"""
|
||||
|
||||
import csv
|
||||
import gzip
|
||||
import json
|
||||
import math
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from hybrid_diffusion import (
|
||||
HybridDiffusionModel,
|
||||
cosine_beta_schedule,
|
||||
q_sample_continuous,
|
||||
q_sample_discrete,
|
||||
)
|
||||
|
||||
DATA_PATH = "/home/anay/Dev/diffusion/dataset/hai/hai-21.03/train1.csv.gz"
|
||||
SPLIT_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/feature_split.json"
|
||||
DEVICE = "cpu"
|
||||
TIMESTEPS = 1000
|
||||
|
||||
|
||||
def load_split(path: str) -> Dict[str, List[str]]:
|
||||
with open(path, "r", encoding="ascii") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def iter_rows(path: str):
|
||||
with gzip.open(path, "rt", newline="") as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
yield row
|
||||
|
||||
|
||||
def build_vocab_sizes(path: str, disc_cols: List[str], max_rows: int = 5000) -> List[int]:
|
||||
values = {c: set() for c in disc_cols}
|
||||
for i, row in enumerate(iter_rows(path)):
|
||||
for c in disc_cols:
|
||||
v = row[c]
|
||||
values[c].add(v)
|
||||
if i + 1 >= max_rows:
|
||||
break
|
||||
sizes = [len(values[c]) for c in disc_cols]
|
||||
return sizes
|
||||
|
||||
|
||||
def load_batch(path: str, cont_cols: List[str], disc_cols: List[str], batch_size: int = 8, seq_len: int = 64):
|
||||
cont = []
|
||||
disc = []
|
||||
current = []
|
||||
for row in iter_rows(path):
|
||||
cont_row = [float(row[c]) for c in cont_cols]
|
||||
disc_row = [int(float(row[c])) for c in disc_cols]
|
||||
current.append((cont_row, disc_row))
|
||||
if len(current) == seq_len:
|
||||
cont.append([r[0] for r in current])
|
||||
disc.append([r[1] for r in current])
|
||||
current = []
|
||||
if len(cont) == batch_size:
|
||||
break
|
||||
x_cont = torch.tensor(cont, dtype=torch.float32)
|
||||
x_disc = torch.tensor(disc, dtype=torch.long)
|
||||
return x_cont, x_disc
|
||||
|
||||
|
||||
def main():
|
||||
split = load_split(SPLIT_PATH)
|
||||
cont_cols = split["continuous"]
|
||||
disc_cols = split["discrete"]
|
||||
|
||||
vocab_sizes = build_vocab_sizes(DATA_PATH, disc_cols)
|
||||
model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(DEVICE)
|
||||
|
||||
betas = cosine_beta_schedule(TIMESTEPS).to(DEVICE)
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
|
||||
x_cont, x_disc = load_batch(DATA_PATH, cont_cols, disc_cols)
|
||||
x_cont = x_cont.to(DEVICE)
|
||||
x_disc = x_disc.to(DEVICE)
|
||||
|
||||
bsz = x_cont.size(0)
|
||||
t = torch.randint(0, TIMESTEPS, (bsz,), device=DEVICE)
|
||||
|
||||
x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
|
||||
|
||||
mask_tokens = torch.tensor(vocab_sizes, device=DEVICE)
|
||||
x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, TIMESTEPS)
|
||||
|
||||
eps_pred, logits = model(x_cont_t, x_disc_t, t)
|
||||
|
||||
loss_cont = F.mse_loss(eps_pred, noise)
|
||||
|
||||
loss_disc = 0.0
|
||||
for i, logit in enumerate(logits):
|
||||
# flatten
|
||||
target = x_disc[:, :, i]
|
||||
if mask.any():
|
||||
loss_disc = loss_disc + F.cross_entropy(logit[mask[:, :, i]], target[mask[:, :, i]])
|
||||
|
||||
lam = 0.5
|
||||
loss = lam * loss_cont + (1 - lam) * loss_disc
|
||||
print("loss_cont", float(loss_cont), "loss_disc", float(loss_disc), "loss", float(loss))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user