Add comprehensive evaluation and ablation runner

This commit is contained in:
MZ YANG
2026-03-25 22:20:43 +08:00
parent f1afd4bf38
commit 957b010ea1
8 changed files with 1730 additions and 30 deletions

View File

@@ -38,6 +38,7 @@ def parse_args():
parser.add_argument("--skip-train", action="store_true")
parser.add_argument("--skip-export", action="store_true")
parser.add_argument("--skip-eval", action="store_true")
parser.add_argument("--skip-comprehensive-eval", action="store_true")
parser.add_argument("--skip-postprocess", action="store_true")
parser.add_argument("--skip-post-eval", action="store_true")
parser.add_argument("--skip-diagnostics", action="store_true")
@@ -212,14 +213,13 @@ def main():
config_paths = expand_config_args(base_dir, args.configs) if args.configs else [resolve_config_path(base_dir, args.config)]
batch_mode = bool(args.configs or args.seeds or (args.repeat and args.repeat > 1) or args.runs_root or args.benchmark_history or args.benchmark_summary)
if not args.skip_prepare:
run([sys.executable, str(base_dir / "prepare_data.py")])
runs_root = Path(args.runs_root) if args.runs_root else (base_dir / "results" / "runs")
history_out = Path(args.benchmark_history) if args.benchmark_history else (base_dir / "results" / "benchmark_history.csv")
summary_out = Path(args.benchmark_summary) if args.benchmark_summary else (base_dir / "results" / "benchmark_summary.csv")
for config_path in config_paths:
if not args.skip_prepare:
run([sys.executable, str(base_dir / "prepare_data.py"), "--config", str(config_path)])
cfg_base = config_path.parent
with open(config_path, "r", encoding="utf-8") as f:
cfg = json.load(f)
@@ -351,6 +351,29 @@ def main():
str(base_dir / "results" / "metrics_history.csv"),
]
)
if not args.skip_comprehensive_eval:
run(
[
sys.executable,
str(base_dir / "evaluate_comprehensive.py"),
"--generated",
str(run_dir / "generated.csv"),
"--reference",
str(config_path),
"--config",
str(cfg_for_steps),
"--split",
str(split_path),
"--stats",
str(stats_path),
"--vocab",
str(vocab_path),
"--out",
str(run_dir / "comprehensive_eval.json"),
"--device",
args.device,
]
)
if not args.skip_postprocess:
cmd = [
@@ -387,6 +410,29 @@ def main():
if ref:
cmd += ["--reference", str(ref)]
run(cmd)
if not args.skip_comprehensive_eval:
run(
[
sys.executable,
str(base_dir / "evaluate_comprehensive.py"),
"--generated",
str(run_dir / "generated_post.csv"),
"--reference",
str(config_path),
"--config",
str(cfg_for_steps),
"--split",
str(split_path),
"--stats",
str(stats_path),
"--vocab",
str(vocab_path),
"--out",
str(run_dir / "comprehensive_eval_post.json"),
"--device",
args.device,
]
)
if not args.skip_diagnostics:
if ref: