Win and linux can run the code
This commit is contained in:
@@ -14,15 +14,16 @@ import torch.nn.functional as F
|
||||
|
||||
from data_utils import load_split
|
||||
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
|
||||
from platform_utils import resolve_device, safe_path, ensure_dir
|
||||
|
||||
|
||||
def load_vocab(path: str) -> Dict[str, Dict[str, int]]:
|
||||
with open(path, "r", encoding="ascii") as f:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)["vocab"]
|
||||
|
||||
|
||||
def load_stats(path: str):
|
||||
with open(path, "r", encoding="ascii") as f:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
@@ -66,17 +67,7 @@ def parse_args():
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def resolve_device(mode: str) -> str:
|
||||
mode = mode.lower()
|
||||
if mode == "cpu":
|
||||
return "cpu"
|
||||
if mode == "cuda":
|
||||
if not torch.cuda.is_available():
|
||||
raise SystemExit("device set to cuda but CUDA is not available")
|
||||
return "cuda"
|
||||
if torch.cuda.is_available():
|
||||
return "cuda"
|
||||
return "cpu"
|
||||
# 使用 platform_utils 中的 resolve_device 函数
|
||||
|
||||
|
||||
def main():
|
||||
@@ -156,7 +147,7 @@ def main():
|
||||
out_cols = [c for c in header if c != time_col or args.include_time]
|
||||
|
||||
os.makedirs(os.path.dirname(args.out), exist_ok=True)
|
||||
with open(args.out, "w", newline="", encoding="ascii") as f:
|
||||
with open(args.out, "w", newline="", encoding="utf-8") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=out_cols)
|
||||
writer.writeheader()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user