#!/usr/bin/env python3
"""Training stub for hybrid diffusion on HAI 21.03.

This is a scaffold that shows data loading, forward noising, and loss setup.
"""

import csv
import gzip
import json
from pathlib import Path
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F

from hybrid_diffusion import (
    HybridDiffusionModel,
    cosine_beta_schedule,
    q_sample_continuous,
    q_sample_discrete,
)

BASE_DIR = Path(__file__).resolve().parent
REPO_DIR = BASE_DIR.parent.parent
DATA_PATH = str(REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz")
SPLIT_PATH = str(BASE_DIR / "feature_split.json")
DEVICE = "cpu"
TIMESTEPS = 1000


def load_split(path: str) -> Dict[str, List[str]]:
    """Load the continuous/discrete column split from JSON."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def iter_rows(path: str):
    """Stream rows from the gzipped CSV as dicts without loading the whole file."""
    with gzip.open(path, "rt", newline="") as f:
        reader = csv.DictReader(f)
        for row in reader:
            yield row


def build_vocab_sizes(path: str, disc_cols: List[str], max_rows: int = 5000) -> List[int]:
    """Estimate a vocabulary size per discrete column from the first rows.

    The columns are assumed to hold small nonnegative integer codes; the size
    is taken as max(code) + 1 so every observed code maps to a valid embedding
    index even when the codes are not contiguous.
    """
    max_code = {c: 0 for c in disc_cols}
    for i, row in enumerate(iter_rows(path)):
        for c in disc_cols:
            max_code[c] = max(max_code[c], int(float(row[c])))
        if i + 1 >= max_rows:
            break
    return [max_code[c] + 1 for c in disc_cols]


def load_batch(
    path: str,
    cont_cols: List[str],
    disc_cols: List[str],
    batch_size: int = 8,
    seq_len: int = 64,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Cut the row stream into fixed-length windows and stack them into a batch."""
    cont: List[List[List[float]]] = []
    disc: List[List[List[int]]] = []
    current: List[Tuple[List[float], List[int]]] = []
    for row in iter_rows(path):
        cont_row = [float(row[c]) for c in cont_cols]
        disc_row = [int(float(row[c])) for c in disc_cols]
        current.append((cont_row, disc_row))
        if len(current) == seq_len:
            cont.append([r[0] for r in current])
            disc.append([r[1] for r in current])
            current = []
            if len(cont) == batch_size:
                break
    # Shapes: (batch, seq, n_cont) float and (batch, seq, n_disc) long.
    x_cont = torch.tensor(cont, dtype=torch.float32)
    x_disc = torch.tensor(disc, dtype=torch.long)
    return x_cont, x_disc


def main() -> None:
    split = load_split(SPLIT_PATH)
    cont_cols = split["continuous"]
    disc_cols = split["discrete"]
    vocab_sizes = build_vocab_sizes(DATA_PATH, disc_cols)

    model = HybridDiffusionModel(
        cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes
    ).to(DEVICE)

    # Forward-process schedule: betas -> alphas -> cumulative product.
    betas = cosine_beta_schedule(TIMESTEPS).to(DEVICE)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    x_cont, x_disc = load_batch(DATA_PATH, cont_cols, disc_cols)
    x_cont = x_cont.to(DEVICE)
    x_disc = x_disc.to(DEVICE)

    # Sample one timestep per sequence and apply the forward noising.
    bsz = x_cont.size(0)
    t = torch.randint(0, TIMESTEPS, (bsz,), device=DEVICE)
    x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)

    # Discrete features use an absorbing-state process: index vocab_sizes[i]
    # serves as the mask token for column i.
    mask_tokens = torch.tensor(vocab_sizes, device=DEVICE)
    x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, TIMESTEPS)

    eps_pred, logits = model(x_cont_t, x_disc_t, t)

    # Continuous branch: predict the injected Gaussian noise.
    loss_cont = F.mse_loss(eps_pred, noise)

    # Discrete branch: cross-entropy on masked positions only. The check is
    # per column; a column with no masked positions would give an empty
    # cross-entropy (NaN) and must be skipped.
    loss_disc = torch.zeros((), device=DEVICE)
    for i, logit in enumerate(logits):
        col_mask = mask[:, :, i]
        if col_mask.any():
            target = x_disc[:, :, i]
            loss_disc = loss_disc + F.cross_entropy(logit[col_mask], target[col_mask])

    lam = 0.5
    loss = lam * loss_cont + (1 - lam) * loss_disc
    print("loss_cont", float(loss_cont), "loss_disc", float(loss_disc), "loss", float(loss))


if __name__ == "__main__":
    main()
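

# ---------------------------------------------------------------------------
# Sketch: a full optimizer step (not part of the original stub). main() above
# computes a single loss and prints it; a training loop would repeat the
# noising + loss computation inside a gradient update. `train_step` is a
# hypothetical helper, assuming the same model API used above, i.e.
# eps_pred, logits = model(x_cont_t, x_disc_t, t).


def train_step(
    model: HybridDiffusionModel,
    optimizer: torch.optim.Optimizer,
    x_cont: torch.Tensor,
    x_disc: torch.Tensor,
    alphas_cumprod: torch.Tensor,
    mask_tokens: torch.Tensor,
    lam: float = 0.5,
) -> float:
    """One forward-noising + loss + backprop step; returns the scalar loss."""
    bsz = x_cont.size(0)
    t = torch.randint(0, TIMESTEPS, (bsz,), device=x_cont.device)
    x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
    x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, TIMESTEPS)

    eps_pred, logits = model(x_cont_t, x_disc_t, t)
    loss_cont = F.mse_loss(eps_pred, noise)
    loss_disc = torch.zeros((), device=x_cont.device)
    for i, logit in enumerate(logits):
        col_mask = mask[:, :, i]
        if col_mask.any():
            loss_disc = loss_disc + F.cross_entropy(
                logit[col_mask], x_disc[:, :, i][col_mask]
            )

    loss = lam * loss_cont + (1 - lam) * loss_disc
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return float(loss)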
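

# ---------------------------------------------------------------------------
# Sketch: per-feature standardization (an assumption, not in the original
# stub). main() feeds raw CSV floats into q_sample_continuous, but Gaussian
# forward noising is usually applied to roughly unit-scale data. In a real
# pipeline the statistics would be estimated over the full training split
# rather than one batch; this helper only shows the shape of the operation.


def standardize(
    x: torch.Tensor, eps: float = 1e-6
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Standardize a (batch, seq, feat) tensor per feature.

    Returns (z, mean, std) so generated samples can later be mapped back via
    z * std + mean.
    """
    mean = x.mean(dim=(0, 1), keepdim=True)
    std = x.std(dim=(0, 1), keepdim=True).clamp_min(eps)
    return (x - mean) / std, mean, std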