This commit is contained in:
Mingzhe Yang
2026-02-04 03:53:17 +08:00
parent 2072351c0d
commit 10c0721ee1
6 changed files with 1134 additions and 104 deletions

View File

@@ -108,6 +108,8 @@ def parse_args():
parser = argparse.ArgumentParser(description="Train hybrid diffusion on HAI.")
parser.add_argument("--config", default=None, help="Path to JSON config.")
parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
parser.add_argument("--out-dir", default=None, help="Override output directory")
parser.add_argument("--seed", type=int, default=None, help="Override random seed")
return parser.parse_args()
@@ -168,6 +170,14 @@ def main():
# 优先使用命令行传入的device参数
if args.device != "auto":
config["device"] = args.device
if args.out_dir:
out_dir = Path(args.out_dir)
if not out_dir.is_absolute():
base = Path(args.config).resolve().parent if args.config else BASE_DIR
out_dir = resolve_path(base, out_dir)
config["out_dir"] = str(out_dir)
if args.seed is not None:
config["seed"] = int(args.seed)
set_seed(int(config["seed"]))