mask-ddpm/example/train_stub.py
#!/usr/bin/env python3
"""Training stub for hybrid diffusion on HAI 21.03.
This is a scaffold that shows data loading, forward noising of continuous and
discrete features, and hybrid loss setup for a single batch; no optimizer step is taken.
"""
import csv
import gzip
import json
from pathlib import Path
from typing import Dict, List
import torch
import torch.nn.functional as F
from hybrid_diffusion import (
    HybridDiffusionModel,
    cosine_beta_schedule,
    q_sample_continuous,
    q_sample_discrete,
)
from platform_utils import resolve_device
BASE_DIR = Path(__file__).resolve().parent
REPO_DIR = BASE_DIR.parent.parent
DATA_PATH = REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz"
SPLIT_PATH = BASE_DIR / "feature_split.json"
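# feature_split.json is expected to map "continuous" and "discrete" to lists of
# column names, e.g. {"continuous": ["P1_B2004", ...], "discrete": ["P1_PP01AD", ...]}
# (the tag names here are illustrative, not taken from the actual split file).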
DEVICE = resolve_device("auto")
TIMESTEPS = 1000

def load_split(path: str) -> Dict[str, List[str]]:
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def iter_rows(path: str):
    with gzip.open(path, "rt", newline="") as f:
        reader = csv.DictReader(f)
        for row in reader:
            yield row

def build_vocab_sizes(path: str, disc_cols: List[str], max_rows: int = 5000) -> List[int]:
    # Scan the first max_rows rows and size each discrete vocabulary as
    # (largest observed token id + 1), matching the integer encoding used in load_batch.
    values = {c: set() for c in disc_cols}
    for i, row in enumerate(iter_rows(path)):
        for c in disc_cols:
            values[c].add(int(float(row[c])))
        if i + 1 >= max_rows:
            break
    return [max(values[c]) + 1 for c in disc_cols]

def load_batch(path: str, cont_cols: List[str], disc_cols: List[str], batch_size: int = 8, seq_len: int = 64):
    # Pack consecutive rows into non-overlapping windows of seq_len timesteps.
    cont = []
    disc = []
    current = []
    for row in iter_rows(path):
        cont_row = [float(row[c]) for c in cont_cols]
        disc_row = [int(float(row[c])) for c in disc_cols]
        current.append((cont_row, disc_row))
        if len(current) == seq_len:
            cont.append([r[0] for r in current])
            disc.append([r[1] for r in current])
            current = []
            if len(cont) == batch_size:
                break
    # x_cont: (batch, seq_len, len(cont_cols)) float; x_disc: (batch, seq_len, len(disc_cols)) long.
    x_cont = torch.tensor(cont, dtype=torch.float32)
    x_disc = torch.tensor(disc, dtype=torch.long)
    return x_cont, x_disc

def main():
    split = load_split(str(SPLIT_PATH))
    cont_cols = split["continuous"]
    disc_cols = split["discrete"]
    vocab_sizes = build_vocab_sizes(str(DATA_PATH), disc_cols)
    model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(DEVICE)
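    # Standard DDPM quantities for the continuous branch: a cosine beta schedule
    # and the cumulative product of alphas used to noise x_0 directly at step t.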
    betas = cosine_beta_schedule(TIMESTEPS).to(DEVICE)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
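    # Load a single fixed batch of windows; a real training loop would iterate
    # over the whole dataset and repeat the steps below per batch.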
    x_cont, x_disc = load_batch(str(DATA_PATH), cont_cols, disc_cols)
    x_cont = x_cont.to(DEVICE)
    x_disc = x_disc.to(DEVICE)
    bsz = x_cont.size(0)
    t = torch.randint(0, TIMESTEPS, (bsz,), device=DEVICE)
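    # Forward-noise the continuous features: q_sample_continuous is expected to
    # return the noised x_t and the Gaussian noise it used, which becomes the MSE target.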
    x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
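    # Forward-noise the discrete features with an absorbing/mask transition.
    # Each column's mask token is index vocab_size, i.e. one past its real tokens;
    # the returned boolean mask marks which positions were absorbed.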
    mask_tokens = torch.tensor(vocab_sizes, device=DEVICE)
    x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, TIMESTEPS)
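    # Hybrid prediction: eps_pred for the continuous noise, plus one logits tensor
    # per discrete column for recovering the original (pre-mask) tokens.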
    eps_pred, logits = model(x_cont_t, x_disc_t, t)
    loss_cont = F.mse_loss(eps_pred, noise)
    loss_disc = 0.0
    for i, logit in enumerate(logits):
        # Cross-entropy only on positions that were masked for this column;
        # boolean indexing flattens (batch, seq_len) into a list of masked positions.
        col_mask = mask[:, :, i]
        if col_mask.any():
            loss_disc = loss_disc + F.cross_entropy(logit[col_mask], x_disc[:, :, i][col_mask])
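    # Blend the two objectives; lam = 0.5 is an arbitrary placeholder weight for this stub.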
    lam = 0.5
    loss = lam * loss_cont + (1 - lam) * loss_disc
    print("loss_cont", float(loss_cont), "loss_disc", float(loss_disc), "loss", float(loss))

if __name__ == "__main__":
    main()