Clean artifacts and update example pipeline
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -11,11 +12,25 @@ import torch.nn.functional as F
|
||||
from data_utils import load_split
|
||||
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
|
||||
|
||||
SPLIT_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/feature_split.json"
|
||||
VOCAB_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/results/disc_vocab.json"
|
||||
MODEL_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/results/model.pt"
|
||||
BASE_DIR = Path(__file__).resolve().parent
|
||||
SPLIT_PATH = str(BASE_DIR / "feature_split.json")
|
||||
VOCAB_PATH = str(BASE_DIR / "results" / "disc_vocab.json")
|
||||
MODEL_PATH = str(BASE_DIR / "results" / "model.pt")
|
||||
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
def resolve_device(mode: str) -> str:
|
||||
mode = mode.lower()
|
||||
if mode == "cpu":
|
||||
return "cpu"
|
||||
if mode == "cuda":
|
||||
if not torch.cuda.is_available():
|
||||
raise SystemExit("device set to cuda but CUDA is not available")
|
||||
return "cuda"
|
||||
if torch.cuda.is_available():
|
||||
return "cuda"
|
||||
return "cpu"
|
||||
|
||||
|
||||
DEVICE = resolve_device("auto")
|
||||
TIMESTEPS = 200
|
||||
SEQ_LEN = 64
|
||||
BATCH_SIZE = 2
|
||||
@@ -28,8 +43,9 @@ def load_vocab():
|
||||
|
||||
def main():
|
||||
split = load_split(SPLIT_PATH)
|
||||
cont_cols = split["continuous"]
|
||||
disc_cols = split["discrete"]
|
||||
time_col = split.get("time_column", "time")
|
||||
cont_cols = [c for c in split["continuous"] if c != time_col]
|
||||
disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
|
||||
|
||||
vocab = load_vocab()
|
||||
vocab_sizes = [len(vocab[c]) for c in disc_cols]
|
||||
|
||||
Reference in New Issue
Block a user