update
This commit is contained in:
@@ -123,9 +123,14 @@ def main():
|
||||
cfg = json.load(f)
|
||||
use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id"
|
||||
if use_condition:
|
||||
base = Path(cfg.get("data_glob", args.data_glob)).parent
|
||||
pat = Path(cfg.get("data_glob", args.data_glob)).name
|
||||
cfg_base = Path(args.config).resolve().parent
|
||||
cfg_glob = cfg.get("data_glob", args.data_glob)
|
||||
cfg_glob = str(resolve_path(cfg_base, cfg_glob))
|
||||
base = Path(cfg_glob).parent
|
||||
pat = Path(cfg_glob).name
|
||||
cond_vocab_size = len(sorted(base.glob(pat)))
|
||||
if cond_vocab_size <= 0:
|
||||
raise SystemExit("use_condition enabled but no files matched data_glob: %s" % cfg_glob)
|
||||
|
||||
model = HybridDiffusionModel(
|
||||
cont_dim=len(cont_cols),
|
||||
|
||||
Reference in New Issue
Block a user