update
This commit is contained in:
@@ -14,7 +14,7 @@ import torch.nn.functional as F
|
||||
|
||||
from data_utils import load_split
|
||||
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
|
||||
from platform_utils import resolve_device, safe_path, ensure_dir
|
||||
from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path
|
||||
|
||||
|
||||
def load_vocab(path: str) -> Dict[str, Dict[str, int]]:
|
||||
@@ -78,6 +78,15 @@ def parse_args():
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
base_dir = Path(__file__).resolve().parent
|
||||
args.data_path = str(resolve_path(base_dir, args.data_path))
|
||||
args.data_glob = str(resolve_path(base_dir, args.data_glob)) if args.data_glob else ""
|
||||
args.split_path = str(resolve_path(base_dir, args.split_path))
|
||||
args.stats_path = str(resolve_path(base_dir, args.stats_path))
|
||||
args.vocab_path = str(resolve_path(base_dir, args.vocab_path))
|
||||
args.model_path = str(resolve_path(base_dir, args.model_path))
|
||||
args.out = str(resolve_path(base_dir, args.out))
|
||||
|
||||
if not os.path.exists(args.model_path):
|
||||
raise SystemExit("missing model file: %s" % args.model_path)
|
||||
|
||||
@@ -107,6 +116,8 @@ def main():
|
||||
cfg = {}
|
||||
use_condition = False
|
||||
cond_vocab_size = 0
|
||||
if args.config:
|
||||
args.config = str(resolve_path(base_dir, args.config))
|
||||
if args.config and os.path.exists(args.config):
|
||||
with open(args.config, "r", encoding="utf-8") as f:
|
||||
cfg = json.load(f)
|
||||
|
||||
Reference in New Issue
Block a user