Update example and notes

This commit is contained in:
2026-01-09 02:14:20 +08:00
parent 200bdf6136
commit c0639386be
18 changed files with 31656 additions and 0 deletions

113
example/train_stub.py Executable file
View File

@@ -0,0 +1,113 @@
#!/usr/bin/env python3
"""Training stub for hybrid diffusion on HAI 21.03.
This is a scaffold that shows data loading, forward noising, and loss setup.
"""
import csv
import gzip
import json
import math
from typing import Dict, List, Tuple
import torch
import torch.nn.functional as F
from hybrid_diffusion import (
HybridDiffusionModel,
cosine_beta_schedule,
q_sample_continuous,
q_sample_discrete,
)
DATA_PATH = "/home/anay/Dev/diffusion/dataset/hai/hai-21.03/train1.csv.gz"
SPLIT_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/feature_split.json"
DEVICE = "cpu"
TIMESTEPS = 1000
def load_split(path: str) -> Dict[str, List[str]]:
with open(path, "r", encoding="ascii") as f:
return json.load(f)
def iter_rows(path: str):
with gzip.open(path, "rt", newline="") as f:
reader = csv.DictReader(f)
for row in reader:
yield row
def build_vocab_sizes(path: str, disc_cols: List[str], max_rows: int = 5000) -> List[int]:
values = {c: set() for c in disc_cols}
for i, row in enumerate(iter_rows(path)):
for c in disc_cols:
v = row[c]
values[c].add(v)
if i + 1 >= max_rows:
break
sizes = [len(values[c]) for c in disc_cols]
return sizes
def load_batch(path: str, cont_cols: List[str], disc_cols: List[str], batch_size: int = 8, seq_len: int = 64):
cont = []
disc = []
current = []
for row in iter_rows(path):
cont_row = [float(row[c]) for c in cont_cols]
disc_row = [int(float(row[c])) for c in disc_cols]
current.append((cont_row, disc_row))
if len(current) == seq_len:
cont.append([r[0] for r in current])
disc.append([r[1] for r in current])
current = []
if len(cont) == batch_size:
break
x_cont = torch.tensor(cont, dtype=torch.float32)
x_disc = torch.tensor(disc, dtype=torch.long)
return x_cont, x_disc
def main():
split = load_split(SPLIT_PATH)
cont_cols = split["continuous"]
disc_cols = split["discrete"]
vocab_sizes = build_vocab_sizes(DATA_PATH, disc_cols)
model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(DEVICE)
betas = cosine_beta_schedule(TIMESTEPS).to(DEVICE)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
x_cont, x_disc = load_batch(DATA_PATH, cont_cols, disc_cols)
x_cont = x_cont.to(DEVICE)
x_disc = x_disc.to(DEVICE)
bsz = x_cont.size(0)
t = torch.randint(0, TIMESTEPS, (bsz,), device=DEVICE)
x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
mask_tokens = torch.tensor(vocab_sizes, device=DEVICE)
x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, TIMESTEPS)
eps_pred, logits = model(x_cont_t, x_disc_t, t)
loss_cont = F.mse_loss(eps_pred, noise)
loss_disc = 0.0
for i, logit in enumerate(logits):
# flatten
target = x_disc[:, :, i]
if mask.any():
loss_disc = loss_disc + F.cross_entropy(logit[mask[:, :, i]], target[mask[:, :, i]])
lam = 0.5
loss = lam * loss_cont + (1 - lam) * loss_disc
print("loss_cont", float(loss_cont), "loss_disc", float(loss_disc), "loss", float(loss))
if __name__ == "__main__":
main()