Win and linux can run the code
This commit is contained in:
@@ -11,23 +11,14 @@ import torch.nn.functional as F
|
||||
|
||||
from data_utils import load_split
|
||||
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
|
||||
from platform_utils import resolve_device, safe_path, ensure_dir
|
||||
|
||||
BASE_DIR = Path(__file__).resolve().parent
|
||||
SPLIT_PATH = str(BASE_DIR / "feature_split.json")
|
||||
VOCAB_PATH = str(BASE_DIR / "results" / "disc_vocab.json")
|
||||
MODEL_PATH = str(BASE_DIR / "results" / "model.pt")
|
||||
SPLIT_PATH = BASE_DIR / "feature_split.json"
|
||||
VOCAB_PATH = BASE_DIR / "results" / "disc_vocab.json"
|
||||
MODEL_PATH = BASE_DIR / "results" / "model.pt"
|
||||
|
||||
def resolve_device(mode: str) -> str:
|
||||
mode = mode.lower()
|
||||
if mode == "cpu":
|
||||
return "cpu"
|
||||
if mode == "cuda":
|
||||
if not torch.cuda.is_available():
|
||||
raise SystemExit("device set to cuda but CUDA is not available")
|
||||
return "cuda"
|
||||
if torch.cuda.is_available():
|
||||
return "cuda"
|
||||
return "cpu"
|
||||
# 使用 platform_utils 中的 resolve_device 函数
|
||||
|
||||
|
||||
DEVICE = resolve_device("auto")
|
||||
@@ -37,12 +28,12 @@ BATCH_SIZE = 2
|
||||
|
||||
|
||||
def load_vocab():
|
||||
with open(VOCAB_PATH, "r", encoding="ascii") as f:
|
||||
with open(str(VOCAB_PATH), "r", encoding="utf-8") as f:
|
||||
return json.load(f)["vocab"]
|
||||
|
||||
|
||||
def main():
|
||||
split = load_split(SPLIT_PATH)
|
||||
split = load_split(str(SPLIT_PATH))
|
||||
time_col = split.get("time_column", "time")
|
||||
cont_cols = [c for c in split["continuous"] if c != time_col]
|
||||
disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
|
||||
@@ -52,8 +43,8 @@ def main():
|
||||
|
||||
print("device", DEVICE)
|
||||
model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(DEVICE)
|
||||
if os.path.exists(MODEL_PATH):
|
||||
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True))
|
||||
if MODEL_PATH.exists():
|
||||
model.load_state_dict(torch.load(str(MODEL_PATH), map_location=DEVICE, weights_only=True))
|
||||
model.eval()
|
||||
|
||||
betas = cosine_beta_schedule(TIMESTEPS).to(DEVICE)
|
||||
|
||||
Reference in New Issue
Block a user