update
This commit is contained in:
@@ -123,9 +123,14 @@ def main():
|
|||||||
cfg = json.load(f)
|
cfg = json.load(f)
|
||||||
use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id"
|
use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id"
|
||||||
if use_condition:
|
if use_condition:
|
||||||
base = Path(cfg.get("data_glob", args.data_glob)).parent
|
cfg_base = Path(args.config).resolve().parent
|
||||||
pat = Path(cfg.get("data_glob", args.data_glob)).name
|
cfg_glob = cfg.get("data_glob", args.data_glob)
|
||||||
|
cfg_glob = str(resolve_path(cfg_base, cfg_glob))
|
||||||
|
base = Path(cfg_glob).parent
|
||||||
|
pat = Path(cfg_glob).name
|
||||||
cond_vocab_size = len(sorted(base.glob(pat)))
|
cond_vocab_size = len(sorted(base.glob(pat)))
|
||||||
|
if cond_vocab_size <= 0:
|
||||||
|
raise SystemExit("use_condition enabled but no files matched data_glob: %s" % cfg_glob)
|
||||||
|
|
||||||
model = HybridDiffusionModel(
|
model = HybridDiffusionModel(
|
||||||
cont_dim=len(cont_cols),
|
cont_dim=len(cont_cols),
|
||||||
|
|||||||
Reference in New Issue
Block a user