Clean artifacts and update example pipeline

This commit is contained in:
2026-01-22 16:32:51 +08:00
parent c0639386be
commit c3f750cd9d
20 changed files with 651 additions and 30826 deletions

View File

@@ -4,6 +4,7 @@
import json
import math
import os
from pathlib import Path
import torch
import torch.nn.functional as F
@@ -11,11 +12,25 @@ import torch.nn.functional as F
from data_utils import load_split
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
SPLIT_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/feature_split.json"
VOCAB_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/results/disc_vocab.json"
MODEL_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/results/model.pt"
BASE_DIR = Path(__file__).resolve().parent
SPLIT_PATH = str(BASE_DIR / "feature_split.json")
VOCAB_PATH = str(BASE_DIR / "results" / "disc_vocab.json")
MODEL_PATH = str(BASE_DIR / "results" / "model.pt")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def resolve_device(mode: str) -> str:
mode = mode.lower()
if mode == "cpu":
return "cpu"
if mode == "cuda":
if not torch.cuda.is_available():
raise SystemExit("device set to cuda but CUDA is not available")
return "cuda"
if torch.cuda.is_available():
return "cuda"
return "cpu"
DEVICE = resolve_device("auto")
TIMESTEPS = 200
SEQ_LEN = 64
BATCH_SIZE = 2
@@ -28,8 +43,9 @@ def load_vocab():
def main():
split = load_split(SPLIT_PATH)
cont_cols = split["continuous"]
disc_cols = split["discrete"]
time_col = split.get("time_column", "time")
cont_cols = [c for c in split["continuous"] if c != time_col]
disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
vocab = load_vocab()
vocab_sizes = [len(vocab[c]) for c in disc_cols]