#!/usr/bin/env python3
"""Training stub for hybrid diffusion on HAI 21.03.

This is a scaffold that shows data loading, forward noising, and loss setup.
"""
import csv
import gzip
import json
from typing import Dict, List

import torch
import torch.nn.functional as F

from hybrid_diffusion import (
    HybridDiffusionModel,
    cosine_beta_schedule,
    q_sample_continuous,
    q_sample_discrete,
)

DATA_PATH = "/home/anay/Dev/diffusion/dataset/hai/hai-21.03/train1.csv.gz"
SPLIT_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/feature_split.json"
DEVICE = "cpu"
TIMESTEPS = 1000


def load_split(path: str) -> Dict[str, List[str]]:
    """Load the continuous/discrete column split from JSON."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def iter_rows(path: str):
    """Stream rows from the gzipped CSV without loading it all into memory."""
    with gzip.open(path, "rt", newline="") as f:
        reader = csv.DictReader(f)
        for row in reader:
            yield row


def build_vocab_sizes(path: str, disc_cols: List[str], max_rows: int = 5000) -> List[int]:
    """Estimate a vocabulary size per discrete column from a prefix of the data.

    load_batch() uses the raw integer codes directly as class indices, so the
    size must be max(code) + 1 rather than the count of distinct values, which
    undercounts when codes are non-contiguous. Assumes non-negative integer
    codes. Scanning only the first max_rows rows is a heuristic; codes that
    first appear past that prefix would be missed.
    """
    max_code = {c: 0 for c in disc_cols}
    for i, row in enumerate(iter_rows(path)):
        for c in disc_cols:
            max_code[c] = max(max_code[c], int(float(row[c])))
        if i + 1 >= max_rows:
            break
    return [max_code[c] + 1 for c in disc_cols]


def load_batch(path: str, cont_cols: List[str], disc_cols: List[str],
               batch_size: int = 8, seq_len: int = 64):
    """Cut the row stream into non-overlapping windows of seq_len rows and
    stack batch_size of them into (B, S, C) float and (B, S, D) long tensors."""
    cont = []
    disc = []
    current = []
    for row in iter_rows(path):
        cont_row = [float(row[c]) for c in cont_cols]
        disc_row = [int(float(row[c])) for c in disc_cols]
        current.append((cont_row, disc_row))
        if len(current) == seq_len:
            cont.append([r[0] for r in current])
            disc.append([r[1] for r in current])
            current = []
            if len(cont) == batch_size:
                break
    x_cont = torch.tensor(cont, dtype=torch.float32)
    x_disc = torch.tensor(disc, dtype=torch.long)
    return x_cont, x_disc


def main():
    split = load_split(SPLIT_PATH)
    cont_cols = split["continuous"]
    disc_cols = split["discrete"]
    vocab_sizes = build_vocab_sizes(DATA_PATH, disc_cols)

    model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(DEVICE)

    # Standard DDPM quantities for the continuous branch.
    betas = cosine_beta_schedule(TIMESTEPS).to(DEVICE)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    x_cont, x_disc = load_batch(DATA_PATH, cont_cols, disc_cols)
    x_cont = x_cont.to(DEVICE)
    x_disc = x_disc.to(DEVICE)

    # One random timestep per sequence in the batch.
    bsz = x_cont.size(0)
    t = torch.randint(0, TIMESTEPS, (bsz,), device=DEVICE)

    # Forward noising: Gaussian noise on the continuous features, absorbing
    # [MASK] tokens on the discrete ones. Each field's mask token is the index
    # just past its vocabulary, so the model's embedding tables and output
    # heads need vocab + 1 entries per field.
    x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
    mask_tokens = torch.tensor(vocab_sizes, device=DEVICE)
    x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, TIMESTEPS)

    eps_pred, logits = model(x_cont_t, x_disc_t, t)

    # Continuous branch: predict the injected noise (epsilon parameterization).
    loss_cont = F.mse_loss(eps_pred, noise)

    # Discrete branch: cross-entropy on the masked positions only, per field.
    loss_disc = x_cont.new_zeros(())
    for i, logit in enumerate(logits):
        col_mask = mask[:, :, i]  # (B, S) bool: positions masked for field i
        if col_mask.any():  # skip fields with no masked positions (empty CE is NaN)
            target = x_disc[:, :, i]
            loss_disc = loss_disc + F.cross_entropy(logit[col_mask], target[col_mask])

    lam = 0.5
    loss = lam * loss_cont + (1 - lam) * loss_disc
    print("loss_cont", float(loss_cont), "loss_disc", float(loss_disc), "loss", float(loss))


if __name__ == "__main__":
    main()
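

# ---------------------------------------------------------------------------
# Reference sketch of the forward-noising helpers (assumption, for orientation
# only). The real implementations live in the local hybrid_diffusion module,
# which is not shown here; the two functions below are hypothetical stand-ins
# inferred from how main() calls them: q_sample_continuous as standard DDPM
# noising, x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps, and
# q_sample_discrete as absorbing-state masking where each token is replaced by
# its field's [MASK] index with probability t / T. Do not treat these as the
# module's actual implementations.
# ---------------------------------------------------------------------------


def q_sample_continuous_sketch(x0: torch.Tensor, t: torch.Tensor,
                               alphas_cumprod: torch.Tensor):
    """Hypothetical DDPM forward noising: returns (x_t, eps)."""
    a_bar = alphas_cumprod[t].view(-1, 1, 1)  # (B, 1, 1), broadcasts over (B, S, C)
    eps = torch.randn_like(x0)
    x_t = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * eps
    return x_t, eps


def q_sample_discrete_sketch(x0: torch.Tensor, t: torch.Tensor,
                             mask_tokens: torch.Tensor, timesteps: int):
    """Hypothetical absorbing-state masking: returns (x_t, mask).

    Assumes a linear masking schedule p = t / T; the real module may use a
    cosine schedule or a different corruption process entirely.
    """
    p = (t.float() / timesteps).view(-1, 1, 1)  # (B, 1, 1) per-sequence mask prob
    mask = torch.rand(x0.shape, device=x0.device) < p  # (B, S, D) bool
    x_t = torch.where(mask, mask_tokens.view(1, 1, -1).expand_as(x0), x0)
    return x_t, mask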