Expand run_all pipeline and docs

This commit is contained in:
2026-01-28 20:17:50 +08:00
parent 39eede92f6
commit f3991cc91e
3 changed files with 47 additions and 6 deletions

View File

@@ -20,14 +20,18 @@ def run(cmd):
def parse_args():
parser = argparse.ArgumentParser(description="Run prepare -> train -> export -> evaluate.")
parser = argparse.ArgumentParser(description="Run full pipeline end-to-end.")
base_dir = Path(__file__).resolve().parent
parser.add_argument("--config", default=str(base_dir / "config.json"))
parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
parser.add_argument("--reference", default="", help="override reference glob (train*.csv.gz)")
parser.add_argument("--skip-prepare", action="store_true")
parser.add_argument("--skip-train", action="store_true")
parser.add_argument("--skip-export", action="store_true")
parser.add_argument("--skip-eval", action="store_true")
parser.add_argument("--skip-postprocess", action="store_true")
parser.add_argument("--skip-post-eval", action="store_true")
parser.add_argument("--skip-diagnostics", action="store_true")
return parser.parse_args()
@@ -79,14 +83,50 @@ def main():
"--use-ema",
]
)
ref = args.reference or cfg.get("data_glob") or cfg.get("data_path") or ""
if not args.skip_eval:
ref = cfg.get("data_glob") or cfg.get("data_path") or ""
if ref:
run([sys.executable, str(base_dir / "evaluate_generated.py"), "--reference", str(ref)])
else:
run([sys.executable, str(base_dir / "evaluate_generated.py")])
run([sys.executable, str(base_dir / "summary_metrics.py")])
run([sys.executable, str(base_dir / "summary_metrics.py")])
if not args.skip_postprocess:
cmd = [
sys.executable,
str(base_dir / "postprocess_types.py"),
"--generated",
str(base_dir / "results" / "generated.csv"),
"--config",
str(config_path),
]
if ref:
cmd += ["--reference", str(ref)]
run(cmd)
if not args.skip_post_eval:
cmd = [
sys.executable,
str(base_dir / "evaluate_generated.py"),
"--generated",
str(base_dir / "results" / "generated_post.csv"),
"--out",
"results/eval_post.json",
]
if ref:
cmd += ["--reference", str(ref)]
run(cmd)
if not args.skip_diagnostics:
if ref:
run([sys.executable, str(base_dir / "diagnose_ks.py"), "--generated", str(base_dir / "results" / "generated_post.csv"), "--reference", str(ref)])
run([sys.executable, str(base_dir / "filtered_metrics.py"), "--eval", str(base_dir / "results" / "eval_post.json")])
run([sys.executable, str(base_dir / "ranked_ks.py"), "--eval", str(base_dir / "results" / "eval_post.json")])
run([sys.executable, str(base_dir / "program_stats.py"), "--config", str(config_path), "--reference", str(ref or config_path)])
run([sys.executable, str(base_dir / "controller_stats.py"), "--config", str(config_path), "--reference", str(ref or config_path)])
run([sys.executable, str(base_dir / "actuator_stats.py"), "--config", str(config_path), "--reference", str(ref or config_path)])
run([sys.executable, str(base_dir / "pv_stats.py"), "--config", str(config_path), "--reference", str(ref or config_path)])
run([sys.executable, str(base_dir / "aux_stats.py"), "--config", str(config_path), "--reference", str(ref or config_path)])
if __name__ == "__main__":