Win and linux can run the code
This commit is contained in:
@@ -20,22 +20,23 @@ from hybrid_diffusion import (
|
||||
q_sample_continuous,
|
||||
q_sample_discrete,
|
||||
)
|
||||
from platform_utils import resolve_device, safe_path, ensure_dir
|
||||
|
||||
BASE_DIR = Path(__file__).resolve().parent
|
||||
REPO_DIR = BASE_DIR.parent.parent
|
||||
DATA_PATH = str(REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz")
|
||||
SPLIT_PATH = str(BASE_DIR / "feature_split.json")
|
||||
DEVICE = "cpu"
|
||||
DATA_PATH = REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz"
|
||||
SPLIT_PATH = BASE_DIR / "feature_split.json"
|
||||
DEVICE = resolve_device("auto")
|
||||
TIMESTEPS = 1000
|
||||
|
||||
|
||||
def load_split(path: str) -> Dict[str, List[str]]:
|
||||
with open(path, "r", encoding="ascii") as f:
|
||||
with open(str(path), "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def iter_rows(path: str):
|
||||
with gzip.open(path, "rt", newline="") as f:
|
||||
with gzip.open(str(path), "rt", newline="") as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
yield row
|
||||
@@ -73,18 +74,18 @@ def load_batch(path: str, cont_cols: List[str], disc_cols: List[str], batch_size
|
||||
|
||||
|
||||
def main():
|
||||
split = load_split(SPLIT_PATH)
|
||||
split = load_split(str(SPLIT_PATH))
|
||||
cont_cols = split["continuous"]
|
||||
disc_cols = split["discrete"]
|
||||
|
||||
vocab_sizes = build_vocab_sizes(DATA_PATH, disc_cols)
|
||||
vocab_sizes = build_vocab_sizes(str(DATA_PATH), disc_cols)
|
||||
model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(DEVICE)
|
||||
|
||||
betas = cosine_beta_schedule(TIMESTEPS).to(DEVICE)
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
|
||||
x_cont, x_disc = load_batch(DATA_PATH, cont_cols, disc_cols)
|
||||
x_cont, x_disc = load_batch(str(DATA_PATH), cont_cols, disc_cols)
|
||||
x_cont = x_cont.to(DEVICE)
|
||||
x_disc = x_disc.to(DEVICE)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user