Make the code run on both Windows and Linux
This commit is contained in:
@@ -18,17 +18,18 @@ from hybrid_diffusion import (
|
||||
q_sample_continuous,
|
||||
q_sample_discrete,
|
||||
)
|
||||
from platform_utils import resolve_device, safe_path, ensure_dir
|
||||
|
||||
|
||||
BASE_DIR = Path(__file__).resolve().parent
|
||||
REPO_DIR = BASE_DIR.parent.parent
|
||||
|
||||
DEFAULTS = {
|
||||
"data_path": str(REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz"),
|
||||
"split_path": str(BASE_DIR / "feature_split.json"),
|
||||
"stats_path": str(BASE_DIR / "results" / "cont_stats.json"),
|
||||
"vocab_path": str(BASE_DIR / "results" / "disc_vocab.json"),
|
||||
"out_dir": str(BASE_DIR / "results"),
|
||||
"data_path": REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz",
|
||||
"split_path": BASE_DIR / "feature_split.json",
|
||||
"stats_path": BASE_DIR / "results" / "cont_stats.json",
|
||||
"vocab_path": BASE_DIR / "results" / "disc_vocab.json",
|
||||
"out_dir": BASE_DIR / "results",
|
||||
"device": "auto",
|
||||
"timesteps": 1000,
|
||||
"batch_size": 8,
|
||||
@@ -44,7 +45,7 @@ DEFAULTS = {
|
||||
|
||||
|
||||
def load_json(path: str) -> Dict:
    """Read a JSON file and return its parsed content.

    The file is decoded as UTF-8 so that non-ASCII text inside the JSON
    document loads the same way regardless of the platform's default
    encoding.

    Args:
        path: Location of the JSON file to read.

    Returns:
        The object produced by ``json.load`` for the file's content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
|
||||
|
||||
|
||||
@@ -57,22 +58,13 @@ def set_seed(seed: int):
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
|
||||
def resolve_device(mode: str) -> str:
    """Map a requested device mode to the device string to use.

    NOTE(review): this listing also imports ``resolve_device`` from
    ``platform_utils``, and a nearby comment says that version should be
    used. This local definition looks like the superseded one -- confirm
    before keeping both, since a local definition would shadow the import.

    Args:
        mode: Requested mode, compared case-insensitively. ``"cpu"`` and
            ``"cuda"`` are honored explicitly; any other value (for example
            ``"auto"``) selects CUDA when it is available and CPU otherwise.

    Returns:
        Either ``"cpu"`` or ``"cuda"``.

    Raises:
        SystemExit: If ``"cuda"`` is requested but CUDA is not available.
    """
    mode = mode.lower()
    if mode == "cpu":
        return "cpu"
    if mode == "cuda":
        # An explicit CUDA request must not silently degrade to CPU.
        if not torch.cuda.is_available():
            raise SystemExit("device set to cuda but CUDA is not available")
        return "cuda"
    # Every other value is treated as a request for auto-detection.
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
|
||||
# Use the resolve_device function from platform_utils
|
||||
|
||||
|
||||
def parse_args():
    """Parse the command-line options of the training script.

    Recognized options:
        --config: Path to a JSON config file; defaults to ``None``.
        --device: One of ``cpu``, ``cuda`` or ``auto``; defaults to ``auto``.
            The value is not validated here.

    Returns:
        The ``argparse.Namespace`` holding ``config`` and ``device``.
    """
    parser = argparse.ArgumentParser(description="Train hybrid diffusion on HAI.")
    parser.add_argument("--config", default=None, help="Path to JSON config.")
    parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
    return parser.parse_args()
|
||||
|
||||
|
||||
@@ -80,9 +72,16 @@ def resolve_config_paths(config, base_dir: Path):
|
||||
keys = ["data_path", "split_path", "stats_path", "vocab_path", "out_dir"]
|
||||
for key in keys:
|
||||
if key in config:
|
||||
path = Path(str(config[key]))
|
||||
# If the value is a string, convert it to a Path object
|
||||
if isinstance(config[key], str):
|
||||
path = Path(config[key])
|
||||
else:
|
||||
path = config[key]
|
||||
|
||||
if not path.is_absolute():
|
||||
config[key] = str((base_dir / path).resolve())
|
||||
else:
|
||||
config[key] = str(path)
|
||||
return config
|
||||
|
||||
|
||||
@@ -96,6 +95,10 @@ def main():
|
||||
else:
|
||||
config = resolve_config_paths(config, BASE_DIR)
|
||||
|
||||
# Prefer the device argument passed on the command line
|
||||
if args.device != "auto":
|
||||
config["device"] = args.device
|
||||
|
||||
set_seed(int(config["seed"]))
|
||||
|
||||
split = load_split(config["split_path"])
|
||||
@@ -121,7 +124,7 @@ def main():
|
||||
|
||||
os.makedirs(config["out_dir"], exist_ok=True)
|
||||
log_path = os.path.join(config["out_dir"], "train_log.csv")
|
||||
with open(log_path, "w", encoding="ascii") as f:
|
||||
with open(log_path, "w", encoding="utf-8") as f:
|
||||
f.write("epoch,step,loss,loss_cont,loss_disc\n")
|
||||
|
||||
total_step = 0
|
||||
@@ -168,7 +171,7 @@ def main():
|
||||
|
||||
if step % int(config["log_every"]) == 0:
|
||||
print("epoch", epoch, "step", step, "loss", float(loss))
|
||||
with open(log_path, "a", encoding="ascii") as f:
|
||||
with open(log_path, "a", encoding="utf-8") as f:
|
||||
f.write(
|
||||
"%d,%d,%.6f,%.6f,%.6f\n"
|
||||
% (epoch, step, float(loss), float(loss_cont), float(loss_disc))
|
||||
|
||||
Reference in New Issue
Block a user