#!/usr/bin/env python3
"""Training stub for hybrid diffusion on HAI 21.03.

This is a scaffold that shows data loading, forward noising, and loss setup
for a single batch; it does not include an optimizer or training loop.
"""

import csv
import gzip
import json
from pathlib import Path
from typing import Dict, List

import torch
import torch.nn.functional as F

from hybrid_diffusion import (
    HybridDiffusionModel,
    cosine_beta_schedule,
    q_sample_continuous,
    q_sample_discrete,
)
from platform_utils import resolve_device

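# Note: hybrid_diffusion and platform_utils are assumed to be local modules in
# this repository rather than installed packages.
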
BASE_DIR = Path(__file__).resolve().parent
REPO_DIR = BASE_DIR.parent.parent
DATA_PATH = REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz"
SPLIT_PATH = BASE_DIR / "feature_split.json"
DEVICE = resolve_device("auto")
TIMESTEPS = 1000

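# Assumed layout of feature_split.json (illustrative only): main() reads
# split["continuous"] and split["discrete"] as lists of HAI column names, e.g.
#   {"continuous": ["<analog sensor column>", ...], "discrete": ["<actuator state column>", ...]}
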
def load_split(path: str) -> Dict[str, List[str]]:
    """Load the mapping of feature groups ("continuous"/"discrete") to column names."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def iter_rows(path: str):
    """Stream rows of a gzipped CSV file as dicts keyed by column name."""
    with gzip.open(path, "rt", newline="") as f:
        reader = csv.DictReader(f)
        for row in reader:
            yield row

def build_vocab_sizes(path: str, disc_cols: List[str], max_rows: int = 5000) -> List[int]:
    """Estimate per-column vocabulary sizes from the first `max_rows` rows.

    Assumes discrete values are already 0-based integer codes (main() uses them
    directly as class indices); categories that only appear after `max_rows`
    rows are not counted.
    """
    values = {c: set() for c in disc_cols}
    for i, row in enumerate(iter_rows(path)):
        for c in disc_cols:
            values[c].add(row[c])
        if i + 1 >= max_rows:
            break
    return [len(values[c]) for c in disc_cols]

def load_batch(path: str, cont_cols: List[str], disc_cols: List[str], batch_size: int = 8, seq_len: int = 64):
    """Read the first `batch_size` non-overlapping windows of `seq_len` rows.

    Returns (x_cont, x_disc) with shapes (batch, seq_len, n_cont) and
    (batch, seq_len, n_disc).
    """
    cont = []
    disc = []
    current = []
    for row in iter_rows(path):
        cont_row = [float(row[c]) for c in cont_cols]
        disc_row = [int(float(row[c])) for c in disc_cols]
        current.append((cont_row, disc_row))
        if len(current) == seq_len:
            cont.append([r[0] for r in current])
            disc.append([r[1] for r in current])
            current = []
            if len(cont) == batch_size:
                break
    x_cont = torch.tensor(cont, dtype=torch.float32)
    x_disc = torch.tensor(disc, dtype=torch.long)
    return x_cont, x_disc

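# Note: continuous columns are used as raw floats here; a full pipeline would
# normally standardize them before forward noising.
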
def main():
    split = load_split(str(SPLIT_PATH))
    cont_cols = split["continuous"]
    disc_cols = split["discrete"]

    vocab_sizes = build_vocab_sizes(str(DATA_PATH), disc_cols)
    model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(DEVICE)

    # Cosine noise schedule and cumulative alpha products for the continuous branch.
    betas = cosine_beta_schedule(TIMESTEPS).to(DEVICE)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    x_cont, x_disc = load_batch(str(DATA_PATH), cont_cols, disc_cols)
    x_cont = x_cont.to(DEVICE)
    x_disc = x_disc.to(DEVICE)

    # One random diffusion timestep per sequence in the batch.
    bsz = x_cont.size(0)
    t = torch.randint(0, TIMESTEPS, (bsz,), device=DEVICE)

    # Forward-noise the continuous features with Gaussian noise.
    x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)

    # Forward-noise the discrete features by masking; each column's mask token
    # sits at index vocab_size.
    mask_tokens = torch.tensor(vocab_sizes, device=DEVICE)
    x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, TIMESTEPS)

    eps_pred, logits = model(x_cont_t, x_disc_t, t)

    # Continuous branch: predict the added Gaussian noise.
    loss_cont = F.mse_loss(eps_pred, noise)

    # Discrete branch: cross-entropy over masked positions only, summed per column.
    loss_disc = 0.0
    for i, logit in enumerate(logits):
        col_mask = mask[:, :, i]
        target = x_disc[:, :, i]
        if col_mask.any():
            loss_disc = loss_disc + F.cross_entropy(logit[col_mask], target[col_mask])

    lam = 0.5
    loss = lam * loss_cont + (1 - lam) * loss_disc
    print("loss_cont", float(loss_cont), "loss_disc", float(loss_disc), "loss", float(loss))

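# A minimal sketch (not part of the original scaffold) of how the combined loss
# would typically drive training; the optimizer choice and learning rate below
# are assumptions:
#
#   opt = torch.optim.Adam(model.parameters(), lr=1e-4)
#   for step in range(num_steps):
#       ... rebuild x_cont_t, x_disc_t, t and recompute `loss` as in main() ...
#       opt.zero_grad()
#       loss.backward()
#       opt.step()
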
if __name__ == "__main__":
    main()