Windows and Linux can run the code

This commit is contained in:
MZ YANG
2026-01-22 17:39:31 +08:00
parent c3f750cd9d
commit f37a8ce179
22 changed files with 32572 additions and 87 deletions

View File

@@ -20,22 +20,23 @@ from hybrid_diffusion import (
q_sample_continuous,
q_sample_discrete,
)
from platform_utils import resolve_device, safe_path, ensure_dir
# Resolve all paths relative to this file so the script works from any
# working directory. Paths are kept as pathlib.Path objects; call sites
# that need strings convert with str() (portable across Windows/Linux).
BASE_DIR = Path(__file__).resolve().parent
REPO_DIR = BASE_DIR.parent.parent
DATA_PATH = REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz"
SPLIT_PATH = BASE_DIR / "feature_split.json"
# NOTE(review): resolve_device("auto") comes from platform_utils —
# presumably it picks a GPU when available and falls back to CPU; confirm there.
DEVICE = resolve_device("auto")
# Number of diffusion timesteps used for the beta schedule in main().
TIMESTEPS = 1000
def load_split(path: str) -> Dict[str, List[str]]:
    """Load the feature-split mapping from a JSON file.

    Args:
        path: Location of the split JSON file; accepts a str or a
            pathlib.Path (coerced with str() below).

    Returns:
        The parsed JSON object — per the callers in this file, a mapping
        with "continuous" and "discrete" keys, each a list of column names.
    """
    # utf-8 (not ascii) so the file loads regardless of platform locale
    # or non-ASCII content in column names.
    with open(str(path), "r", encoding="utf-8") as f:
        return json.load(f)
def iter_rows(path: str):
    """Yield each data row of a gzipped CSV file as a dict.

    Args:
        path: Location of a ``.csv.gz`` file; accepts a str or a
            pathlib.Path (coerced with str() below).

    Yields:
        One dict per row, mapping header column name to the cell's string
        value (csv.DictReader semantics).
    """
    # Text-mode gzip read; newline="" hands line-ending handling to the
    # csv module, which keeps parsing portable across Windows/Linux.
    with gzip.open(str(path), "rt", newline="") as f:
        reader = csv.DictReader(f)
        for row in reader:
            yield row
@@ -73,18 +74,18 @@ def load_batch(path: str, cont_cols: List[str], disc_cols: List[str], batch_size
def main():
split = load_split(SPLIT_PATH)
split = load_split(str(SPLIT_PATH))
cont_cols = split["continuous"]
disc_cols = split["discrete"]
vocab_sizes = build_vocab_sizes(DATA_PATH, disc_cols)
vocab_sizes = build_vocab_sizes(str(DATA_PATH), disc_cols)
model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(DEVICE)
betas = cosine_beta_schedule(TIMESTEPS).to(DEVICE)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
x_cont, x_disc = load_batch(DATA_PATH, cont_cols, disc_cols)
x_cont, x_disc = load_batch(str(DATA_PATH), cont_cols, disc_cols)
x_cont = x_cont.to(DEVICE)
x_disc = x_disc.to(DEVICE)