Win and linux can run the code

This commit is contained in:
MZ YANG
2026-01-22 17:39:31 +08:00
parent c3f750cd9d
commit f37a8ce179
22 changed files with 32572 additions and 87 deletions

View File

@@ -11,23 +11,14 @@ import torch.nn.functional as F
from data_utils import load_split
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
from platform_utils import resolve_device, safe_path, ensure_dir
BASE_DIR = Path(__file__).resolve().parent
SPLIT_PATH = str(BASE_DIR / "feature_split.json")
VOCAB_PATH = str(BASE_DIR / "results" / "disc_vocab.json")
MODEL_PATH = str(BASE_DIR / "results" / "model.pt")
SPLIT_PATH = BASE_DIR / "feature_split.json"
VOCAB_PATH = BASE_DIR / "results" / "disc_vocab.json"
MODEL_PATH = BASE_DIR / "results" / "model.pt"
def resolve_device(mode: str) -> str:
mode = mode.lower()
if mode == "cpu":
return "cpu"
if mode == "cuda":
if not torch.cuda.is_available():
raise SystemExit("device set to cuda but CUDA is not available")
return "cuda"
if torch.cuda.is_available():
return "cuda"
return "cpu"
# 使用 platform_utils 中的 resolve_device 函数
DEVICE = resolve_device("auto")
@@ -37,12 +28,12 @@ BATCH_SIZE = 2
def load_vocab():
with open(VOCAB_PATH, "r", encoding="ascii") as f:
with open(str(VOCAB_PATH), "r", encoding="utf-8") as f:
return json.load(f)["vocab"]
def main():
split = load_split(SPLIT_PATH)
split = load_split(str(SPLIT_PATH))
time_col = split.get("time_column", "time")
cont_cols = [c for c in split["continuous"] if c != time_col]
disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
@@ -52,8 +43,8 @@ def main():
print("device", DEVICE)
model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(DEVICE)
if os.path.exists(MODEL_PATH):
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True))
if MODEL_PATH.exists():
model.load_state_dict(torch.load(str(MODEL_PATH), map_location=DEVICE, weights_only=True))
model.eval()
betas = cosine_beta_schedule(TIMESTEPS).to(DEVICE)