diff --git a/example/export_samples.py b/example/export_samples.py index 0b4d5df..cec7929 100644 --- a/example/export_samples.py +++ b/example/export_samples.py @@ -112,8 +112,6 @@ def main(): int_like = stats.get("int_like", {}) max_decimals = stats.get("max_decimals", {}) transforms = stats.get("transform", {}) - cont_target = str(cfg.get("cont_target", "eps")) - cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0)) vocab_json = json.load(open(args.vocab_path, "r", encoding="utf-8")) vocab = vocab_json["vocab"] @@ -140,6 +138,8 @@ def main(): cond_vocab_size = len(sorted(base.glob(pat))) if cond_vocab_size <= 0: raise SystemExit("use_condition enabled but no files matched data_glob: %s" % cfg_glob) + cont_target = str(cfg.get("cont_target", "eps")) + cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0)) model = HybridDiffusionModel( cont_dim=len(cont_cols),