diff --git a/example/evaluate_generated.py b/example/evaluate_generated.py index ee1872f..b5431e8 100644 --- a/example/evaluate_generated.py +++ b/example/evaluate_generated.py @@ -57,6 +57,12 @@ def finalize_stats(stats): def main(): args = parse_args() + base_dir = Path(__file__).resolve().parent + args.generated = str((base_dir / args.generated).resolve()) if not Path(args.generated).is_absolute() else args.generated + args.split = str((base_dir / args.split).resolve()) if not Path(args.split).is_absolute() else args.split + args.stats = str((base_dir / args.stats).resolve()) if not Path(args.stats).is_absolute() else args.stats + args.vocab = str((base_dir / args.vocab).resolve()) if not Path(args.vocab).is_absolute() else args.vocab + args.out = str((base_dir / args.out).resolve()) if not Path(args.out).is_absolute() else args.out split = load_json(args.split) time_col = split.get("time_column", "time") cont_cols = [c for c in split["continuous"] if c != time_col] diff --git a/example/export_samples.py b/example/export_samples.py index 8e05e92..3c09e92 100644 --- a/example/export_samples.py +++ b/example/export_samples.py @@ -14,7 +14,7 @@ import torch.nn.functional as F from data_utils import load_split from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule -from platform_utils import resolve_device, safe_path, ensure_dir +from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path def load_vocab(path: str) -> Dict[str, Dict[str, int]]: @@ -78,6 +78,15 @@ def parse_args(): def main(): args = parse_args() + base_dir = Path(__file__).resolve().parent + args.data_path = str(resolve_path(base_dir, args.data_path)) + args.data_glob = str(resolve_path(base_dir, args.data_glob)) if args.data_glob else "" + args.split_path = str(resolve_path(base_dir, args.split_path)) + args.stats_path = str(resolve_path(base_dir, args.stats_path)) + args.vocab_path = str(resolve_path(base_dir, args.vocab_path)) + args.model_path = str(resolve_path(base_dir, args.model_path)) + args.out = str(resolve_path(base_dir, args.out)) + if not os.path.exists(args.model_path): raise SystemExit("missing model file: %s" % args.model_path) @@ -107,6 +116,8 @@ def main(): cfg = {} use_condition = False cond_vocab_size = 0 + if args.config: + args.config = str(resolve_path(base_dir, args.config)) if args.config and os.path.exists(args.config): with open(args.config, "r", encoding="utf-8") as f: cfg = json.load(f) diff --git a/example/platform_utils.py b/example/platform_utils.py index 125b3d0..c0f4221 100644 --- a/example/platform_utils.py +++ b/example/platform_utils.py @@ -174,6 +174,24 @@ def get_relative_path(base: Union[str, Path], target: Union[str, Path]) -> Path: return (base_path / target_path).resolve() +def resolve_path(base: Union[str, Path], target: Union[str, Path]) -> Path: + """ + Resolve target path against base if target is relative. + + Args: + base: base directory + target: target path (absolute or relative) + + Returns: + Absolute Path + """ + base_path = Path(base) if isinstance(base, str) else base + target_path = Path(target) if isinstance(target, str) else target + if target_path.is_absolute(): + return target_path + return (base_path / target_path).resolve() + + def print_platform_summary(): """打印平台摘要信息""" info = get_platform_info() @@ -212,4 +230,4 @@ if __name__ == "__main__": print("\n路径处理测试:") test_path = "some/path/to/file.txt" print(f" 原始路径: {test_path}") - print(f" 安全路径: {safe_path(test_path)}") \ No newline at end of file + print(f" 安全路径: {safe_path(test_path)}")