update
This commit is contained in:
@@ -108,6 +108,8 @@ def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Train hybrid diffusion on HAI.")
|
||||
parser.add_argument("--config", default=None, help="Path to JSON config.")
|
||||
parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
|
||||
parser.add_argument("--out-dir", default=None, help="Override output directory")
|
||||
parser.add_argument("--seed", type=int, default=None, help="Override random seed")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -168,6 +170,14 @@ def main():
|
||||
# 优先使用命令行传入的device参数
|
||||
if args.device != "auto":
|
||||
config["device"] = args.device
|
||||
if args.out_dir:
|
||||
out_dir = Path(args.out_dir)
|
||||
if not out_dir.is_absolute():
|
||||
base = Path(args.config).resolve().parent if args.config else BASE_DIR
|
||||
out_dir = resolve_path(base, out_dir)
|
||||
config["out_dir"] = str(out_dir)
|
||||
if args.seed is not None:
|
||||
config["seed"] = int(args.seed)
|
||||
|
||||
set_seed(int(config["seed"]))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user