This commit is contained in:
2026-01-22 20:57:56 +08:00
parent 518dea58f6
commit 9ba96fe77e

View File

@@ -123,9 +123,14 @@ def main():
cfg = json.load(f) cfg = json.load(f)
use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id" use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id"
if use_condition: if use_condition:
base = Path(cfg.get("data_glob", args.data_glob)).parent cfg_base = Path(args.config).resolve().parent
pat = Path(cfg.get("data_glob", args.data_glob)).name cfg_glob = cfg.get("data_glob", args.data_glob)
cfg_glob = str(resolve_path(cfg_base, cfg_glob))
base = Path(cfg_glob).parent
pat = Path(cfg_glob).name
cond_vocab_size = len(sorted(base.glob(pat))) cond_vocab_size = len(sorted(base.glob(pat)))
if cond_vocab_size <= 0:
raise SystemExit("use_condition enabled but no files matched data_glob: %s" % cfg_glob)
model = HybridDiffusionModel( model = HybridDiffusionModel(
cont_dim=len(cont_cols), cont_dim=len(cont_cols),