update
This commit is contained in:
@@ -57,6 +57,12 @@ def finalize_stats(stats):
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
base_dir = Path(__file__).resolve().parent
|
||||||
|
args.generated = str((base_dir / args.generated).resolve()) if not Path(args.generated).is_absolute() else args.generated
|
||||||
|
args.split = str((base_dir / args.split).resolve()) if not Path(args.split).is_absolute() else args.split
|
||||||
|
args.stats = str((base_dir / args.stats).resolve()) if not Path(args.stats).is_absolute() else args.stats
|
||||||
|
args.vocab = str((base_dir / args.vocab).resolve()) if not Path(args.vocab).is_absolute() else args.vocab
|
||||||
|
args.out = str((base_dir / args.out).resolve()) if not Path(args.out).is_absolute() else args.out
|
||||||
split = load_json(args.split)
|
split = load_json(args.split)
|
||||||
time_col = split.get("time_column", "time")
|
time_col = split.get("time_column", "time")
|
||||||
cont_cols = [c for c in split["continuous"] if c != time_col]
|
cont_cols = [c for c in split["continuous"] if c != time_col]
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import torch.nn.functional as F
|
|||||||
|
|
||||||
from data_utils import load_split
|
from data_utils import load_split
|
||||||
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
|
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
|
||||||
from platform_utils import resolve_device, safe_path, ensure_dir
|
from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path
|
||||||
|
|
||||||
|
|
||||||
def load_vocab(path: str) -> Dict[str, Dict[str, int]]:
|
def load_vocab(path: str) -> Dict[str, Dict[str, int]]:
|
||||||
@@ -78,6 +78,15 @@ def parse_args():
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
base_dir = Path(__file__).resolve().parent
|
||||||
|
args.data_path = str(resolve_path(base_dir, args.data_path))
|
||||||
|
args.data_glob = str(resolve_path(base_dir, args.data_glob)) if args.data_glob else ""
|
||||||
|
args.split_path = str(resolve_path(base_dir, args.split_path))
|
||||||
|
args.stats_path = str(resolve_path(base_dir, args.stats_path))
|
||||||
|
args.vocab_path = str(resolve_path(base_dir, args.vocab_path))
|
||||||
|
args.model_path = str(resolve_path(base_dir, args.model_path))
|
||||||
|
args.out = str(resolve_path(base_dir, args.out))
|
||||||
|
|
||||||
if not os.path.exists(args.model_path):
|
if not os.path.exists(args.model_path):
|
||||||
raise SystemExit("missing model file: %s" % args.model_path)
|
raise SystemExit("missing model file: %s" % args.model_path)
|
||||||
|
|
||||||
@@ -107,6 +116,8 @@ def main():
|
|||||||
cfg = {}
|
cfg = {}
|
||||||
use_condition = False
|
use_condition = False
|
||||||
cond_vocab_size = 0
|
cond_vocab_size = 0
|
||||||
|
if args.config:
|
||||||
|
args.config = str(resolve_path(base_dir, args.config))
|
||||||
if args.config and os.path.exists(args.config):
|
if args.config and os.path.exists(args.config):
|
||||||
with open(args.config, "r", encoding="utf-8") as f:
|
with open(args.config, "r", encoding="utf-8") as f:
|
||||||
cfg = json.load(f)
|
cfg = json.load(f)
|
||||||
|
|||||||
@@ -174,6 +174,24 @@ def get_relative_path(base: Union[str, Path], target: Union[str, Path]) -> Path:
|
|||||||
return (base_path / target_path).resolve()
|
return (base_path / target_path).resolve()
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_path(base: Union[str, Path], target: Union[str, Path]) -> Path:
|
||||||
|
"""
|
||||||
|
Resolve target path against base if target is relative.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base: base directory
|
||||||
|
target: target path (absolute or relative)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Absolute Path
|
||||||
|
"""
|
||||||
|
base_path = Path(base) if isinstance(base, str) else base
|
||||||
|
target_path = Path(target) if isinstance(target, str) else target
|
||||||
|
if target_path.is_absolute():
|
||||||
|
return target_path
|
||||||
|
return (base_path / target_path).resolve()
|
||||||
|
|
||||||
|
|
||||||
def print_platform_summary():
|
def print_platform_summary():
|
||||||
"""打印平台摘要信息"""
|
"""打印平台摘要信息"""
|
||||||
info = get_platform_info()
|
info = get_platform_info()
|
||||||
@@ -212,4 +230,4 @@ if __name__ == "__main__":
|
|||||||
print("\n路径处理测试:")
|
print("\n路径处理测试:")
|
||||||
test_path = "some/path/to/file.txt"
|
test_path = "some/path/to/file.txt"
|
||||||
print(f" 原始路径: {test_path}")
|
print(f" 原始路径: {test_path}")
|
||||||
print(f" 安全路径: {safe_path(test_path)}")
|
print(f" 安全路径: {safe_path(test_path)}")
|
||||||
|
|||||||
Reference in New Issue
Block a user