Clean artifacts and update example pipeline
This commit is contained in:
@@ -8,6 +8,7 @@ import csv
|
||||
import gzip
|
||||
import json
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
@@ -20,8 +21,10 @@ from hybrid_diffusion import (
|
||||
q_sample_discrete,
|
||||
)
|
||||
|
||||
DATA_PATH = "/home/anay/Dev/diffusion/dataset/hai/hai-21.03/train1.csv.gz"
|
||||
SPLIT_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/feature_split.json"
|
||||
BASE_DIR = Path(__file__).resolve().parent
|
||||
REPO_DIR = BASE_DIR.parent.parent
|
||||
DATA_PATH = str(REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz")
|
||||
SPLIT_PATH = str(BASE_DIR / "feature_split.json")
|
||||
DEVICE = "cpu"
|
||||
TIMESTEPS = 1000
|
||||
|
||||
|
||||
Reference in New Issue
Block a user