#!/usr/bin/env python3
"""One-command pipeline runner with config-driven paths."""

import argparse
import csv
import json
import math
import subprocess
import sys
from pathlib import Path

from platform_utils import safe_path, is_windows


def run(cmd):
    # Normalize every argument before printing so the logged command matches
    # what is actually executed.
    cmd = [safe_path(arg) for arg in cmd]
    print("running:", " ".join(cmd))
    # shell=False is already the default; it is made explicit on Windows so the
    # argument list is never reinterpreted by a shell.
    if is_windows():
        subprocess.run(cmd, check=True, shell=False)
    else:
        subprocess.run(cmd, check=True)


def parse_args():
    parser = argparse.ArgumentParser(description="Run full pipeline end-to-end.")
    base_dir = Path(__file__).resolve().parent
    parser.add_argument("--config", default=str(base_dir / "config.json"))
    parser.add_argument("--configs", default="", help="Comma-separated configs or globs for batch runs")
    parser.add_argument("--seeds", default="", help="Comma-separated seeds for batch runs")
    parser.add_argument("--repeat", type=int, default=1, help="Repeat each config with different seeds")
    parser.add_argument("--runs-root", default="", help="Root directory for per-run artifacts (batch)")
    parser.add_argument("--benchmark-history", default="", help="CSV path for batch history output")
    parser.add_argument("--benchmark-summary", default="", help="CSV path for batch summary output")
    parser.add_argument("--name-prefix", default="", help="Prefix for batch run directory names")
    parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
    parser.add_argument("--reference", default="", help="Override reference glob (train*.csv.gz)")
    parser.add_argument("--skip-prepare", action="store_true")
    parser.add_argument("--skip-train", action="store_true")
    parser.add_argument("--skip-export", action="store_true")
    parser.add_argument("--skip-eval", action="store_true")
    parser.add_argument("--skip-postprocess", action="store_true")
    parser.add_argument("--skip-post-eval", action="store_true")
    parser.add_argument("--skip-diagnostics", action="store_true")
    return parser.parse_args()


def resolve_config_path(base_dir: Path, cfg_arg: str) -> Path:
    p = Path(cfg_arg)
    if p.is_absolute():
        if p.exists():
            return p.resolve()
        raise SystemExit(f"config not found: {p}")
    repo_dir = base_dir.parent
    candidates = [p, base_dir / p, repo_dir / p]
    # Allow paths written relative to the repo root, e.g. "example/config.json".
    if p.parts and p.parts[0] == "example":
        trimmed = Path(*p.parts[1:]) if len(p.parts) > 1 else Path()
        if str(trimmed):
            candidates.extend([base_dir / trimmed, repo_dir / trimmed])
    for c in candidates:
        if c.exists():
            return c.resolve()
    tried = "\n".join(str(c) for c in candidates)
    raise SystemExit(f"config not found: {cfg_arg}\ntried:\n{tried}")


def resolve_like(base: Path, value: str) -> str:
    if not value:
        return ""
    p = Path(value)
    if p.is_absolute():
        return str(p)
    s = str(value)
    if any(ch in s for ch in ["*", "?", "["]):
        # Globs are kept unresolved so downstream tools can expand them.
        return str(base / p)
    return str((base / p).resolve())


def expand_config_args(base_dir: Path, arg: str):
    if not arg:
        return []
    repo_dir = base_dir.parent
    tokens = [t.strip() for t in arg.split(",") if t.strip()]
    out = []
    for t in tokens:
        if any(ch in t for ch in ["*", "?", "["]):
            p = Path(t)
            if p.is_absolute():
                matches = sorted(p.parent.glob(p.name))
                if not matches:
                    raise SystemExit(f"no configs matched glob: {t}")
                out.extend(matches)
            else:
                candidates = [base_dir / p, repo_dir / p, p]
                matched = False
                for c in candidates:
                    matches = sorted(c.parent.glob(c.name))
                    if matches:
                        out.extend(matches)
                        matched = True
                        break
                if not matched:
                    raise SystemExit(f"no configs matched glob: {t}")
        else:
            out.append(resolve_config_path(base_dir, t))
    # De-duplicate while preserving order.
    seen = set()
    uniq = []
    for p in out:
        rp = str(Path(p).resolve())
        if rp in seen:
            continue
        seen.add(rp)
        uniq.append(Path(rp))
    return uniq

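# Example (hypothetical layout): with this script in <repo>/example/,
# expand_config_args(base_dir, "configs/*.json,baseline.json") matches the
# glob against example/, then the repo root, then the literal relative path,
# resolves baseline.json via resolve_config_path, and de-duplicates the
# combined list by resolved absolute path.
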
def parse_seeds(arg: str):
    if not arg:
        return []
    out = []
    for part in [p.strip() for p in arg.split(",") if p.strip()]:
        out.append(int(part))
    return out


def compute_summary(history_path: Path, out_path: Path):
    if not history_path.exists():
        return
    rows = []
    with history_path.open("r", encoding="utf-8", newline="") as f:
        reader = csv.DictReader(f)
        for r in reader:
            rows.append(r)
    if not rows:
        return

    def to_float(v):
        try:
            return float(v)
        except Exception:
            return None

    grouped = {}
    for r in rows:
        cfg = r.get("config", "") or ""
        grouped.setdefault(cfg, []).append(r)

    def mean_std(vals):
        xs = [x for x in vals if x is not None]
        if not xs:
            return None, None
        mu = sum(xs) / len(xs)
        if len(xs) <= 1:
            return mu, 0.0
        # Sample (n - 1) variance.
        var = sum((x - mu) ** 2 for x in xs) / (len(xs) - 1)
        return mu, math.sqrt(var)

    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8", newline="") as f:
        fieldnames = [
            "config",
            "n_runs",
            "avg_ks_mean",
            "avg_ks_std",
            "avg_jsd_mean",
            "avg_jsd_std",
            "avg_lag1_diff_mean",
            "avg_lag1_diff_std",
            "best_run_name",
            "best_avg_ks",
        ]
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        for cfg, rs in sorted(grouped.items(), key=lambda kv: kv[0]):
            kss = [to_float(x.get("avg_ks")) for x in rs]
            jsds = [to_float(x.get("avg_jsd")) for x in rs]
            lags = [to_float(x.get("avg_lag1_diff")) for x in rs]
            ks_mu, ks_sd = mean_std(kss)
            jsd_mu, jsd_sd = mean_std(jsds)
            lag_mu, lag_sd = mean_std(lags)
            best = None
            for r in rs:
                ks = to_float(r.get("avg_ks"))
                if ks is None:
                    continue
                if best is None or ks < best[0]:
                    best = (ks, r.get("run_name", ""))
            writer.writerow(
                {
                    "config": cfg,
                    "n_runs": len(rs),
                    "avg_ks_mean": ks_mu,
                    "avg_ks_std": ks_sd,
                    "avg_jsd_mean": jsd_mu,
                    "avg_jsd_std": jsd_sd,
                    "avg_lag1_diff_mean": lag_mu,
                    "avg_lag1_diff_std": lag_sd,
                    "best_run_name": "" if best is None else best[1],
                    "best_avg_ks": None if best is None else best[0],
                }
            )


def main():
    args = parse_args()
    base_dir = Path(__file__).resolve().parent
    seed_list = parse_seeds(args.seeds)
    config_paths = (
        expand_config_args(base_dir, args.configs)
        if args.configs
        else [resolve_config_path(base_dir, args.config)]
    )
    batch_mode = bool(
        args.configs
        or args.seeds
        or (args.repeat and args.repeat > 1)
        or args.runs_root
        or args.benchmark_history
        or args.benchmark_summary
    )
    if not args.skip_prepare:
        run([sys.executable, str(base_dir / "prepare_data.py")])
    runs_root = Path(args.runs_root) if args.runs_root else (base_dir / "results" / "runs")
    history_out = (
        Path(args.benchmark_history)
        if args.benchmark_history
        else (base_dir / "results" / "benchmark_history.csv")
    )
    summary_out = (
        Path(args.benchmark_summary)
        if args.benchmark_summary
        else (base_dir / "results" / "benchmark_summary.csv")
    )
    for config_path in config_paths:
        cfg_base = config_path.parent
        with open(config_path, "r", encoding="utf-8") as f:
            cfg = json.load(f)
        timesteps = cfg.get("timesteps", 200)
        seq_len = cfg.get("sample_seq_len", cfg.get("seq_len", 64))
        batch_size = cfg.get("sample_batch_size", cfg.get("batch_size", 2))
        clip_k = cfg.get("clip_k", 5.0)
        seeds = seed_list
        if not seeds:
            base_seed = int(cfg.get("seed", 1337))
            if args.repeat and args.repeat > 1:
                seeds = [base_seed + i for i in range(int(args.repeat))]
            else:
                seeds = [base_seed]
        for seed in seeds:
            run_dir = (
                base_dir / "results"
                if not batch_mode
                else runs_root / f"{args.name_prefix}{config_path.stem}__seed{seed}"
            )
            run_dir.mkdir(parents=True, exist_ok=True)
            data_path = resolve_like(cfg_base, str(cfg.get("data_path", "")))
            data_glob = resolve_like(cfg_base, str(cfg.get("data_glob", "")))
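            # Inputs from the config resolve relative to the config file's
            # directory; split/stats/vocab fall back to defaults next to this
            # script.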
            split_path = resolve_like(cfg_base, str(cfg.get("split_path", ""))) or str(
                base_dir / "feature_split.json"
            )
            stats_path = resolve_like(cfg_base, str(cfg.get("stats_path", ""))) or str(
                base_dir / "results" / "cont_stats.json"
            )
            vocab_path = resolve_like(cfg_base, str(cfg.get("vocab_path", ""))) or str(
                base_dir / "results" / "disc_vocab.json"
            )
            ref = args.reference or cfg.get("data_glob") or cfg.get("data_path") or ""
            ref = resolve_like(cfg_base, str(ref)) if ref else ""
            if not args.skip_train:
                run(
                    [
                        sys.executable,
                        str(base_dir / "train.py"),
                        "--config",
                        str(config_path),
                        "--device",
                        args.device,
                        "--out-dir",
                        str(run_dir),
                        "--seed",
                        str(seed),
                    ]
                )
            # Prefer the exact config the training run recorded, if present.
            config_used = run_dir / "config_used.json"
            cfg_for_steps = config_used if config_used.exists() else config_path
            if not args.skip_export:
                run(
                    [
                        sys.executable,
                        str(base_dir / "export_samples.py"),
                        "--include-time",
                        "--device",
                        args.device,
                        "--config",
                        str(cfg_for_steps),
                        "--data-path",
                        str(data_path),
                        "--data-glob",
                        str(data_glob),
                        "--split-path",
                        str(split_path),
                        "--stats-path",
                        str(stats_path),
                        "--vocab-path",
                        str(vocab_path),
                        "--model-path",
                        str(run_dir / "model.pt"),
                        "--out",
                        str(run_dir / "generated.csv"),
                        "--timesteps",
                        str(timesteps),
                        "--seq-len",
                        str(seq_len),
                        "--batch-size",
                        str(batch_size),
                        "--clip-k",
                        str(clip_k),
                        "--use-ema",
                    ]
                )
            if not args.skip_eval:
                cmd = [
                    sys.executable,
                    str(base_dir / "evaluate_generated.py"),
                    "--generated",
                    str(run_dir / "generated.csv"),
                    "--split",
                    str(split_path),
                    "--stats",
                    str(stats_path),
                    "--vocab",
                    str(vocab_path),
                    "--out",
                    str(run_dir / "eval.json"),
                ]
                if ref:
                    cmd += ["--reference", str(ref)]
                run(cmd)
                if batch_mode:
                    run(
                        [
                            sys.executable,
                            str(base_dir / "summary_metrics.py"),
                            "--eval",
                            str(run_dir / "eval.json"),
                            "--history",
                            str(history_out),
                            "--run-name",
                            run_dir.name,
                            "--config",
                            str(config_path),
                            "--seed",
                            str(seed),
                        ]
                    )
                else:
                    run(
                        [
                            sys.executable,
                            str(base_dir / "summary_metrics.py"),
                            "--eval",
                            str(run_dir / "eval.json"),
                            "--history",
                            str(base_dir / "results" / "metrics_history.csv"),
                        ]
                    )
            if not args.skip_postprocess:
                cmd = [
                    sys.executable,
                    str(base_dir / "postprocess_types.py"),
                    "--generated",
                    str(run_dir / "generated.csv"),
                    "--config",
                    str(cfg_for_steps),
                    "--out",
                    str(run_dir / "generated_post.csv"),
                    "--seed",
                    str(seed),
                ]
                if ref:
                    cmd += ["--reference", str(ref)]
                run(cmd)
            if not args.skip_post_eval:
                cmd = [
                    sys.executable,
                    str(base_dir / "evaluate_generated.py"),
                    "--generated",
                    str(run_dir / "generated_post.csv"),
                    "--split",
                    str(split_path),
                    "--stats",
                    str(stats_path),
                    "--vocab",
                    str(vocab_path),
                    "--out",
                    str(run_dir / "eval_post.json"),
                ]
                if ref:
                    cmd += ["--reference", str(ref)]
                run(cmd)
            if not args.skip_diagnostics:
                if ref:
                    run(
                        [
                            sys.executable,
                            str(base_dir / "diagnose_ks.py"),
                            "--generated",
                            str(run_dir / "generated_post.csv"),
                            "--reference",
                            str(ref),
                        ]
                    )
                run(
                    [
                        sys.executable,
                        str(base_dir / "filtered_metrics.py"),
                        "--eval",
                        str(run_dir / "eval_post.json"),
                        "--out",
                        str(run_dir / "filtered_metrics.json"),
                    ]
                )
                run(
                    [
                        sys.executable,
                        str(base_dir / "ranked_ks.py"),
                        "--eval",
                        str(run_dir / "eval_post.json"),
                        "--out",
                        str(run_dir / "ranked_ks.csv"),
                    ]
                )
                # Shared arguments for the *_stats.py diagnostics; the reference
                # data is forwarded only when one was resolved.
                stats_common = [
                    "--generated",
                    str(run_dir / "generated_post.csv"),
                    "--config",
                    str(cfg_for_steps),
                ]
                if ref:
                    stats_common += ["--reference", str(ref)]
                run(
                    [sys.executable, str(base_dir / "program_stats.py")]
                    + stats_common
                    + ["--out", str(run_dir / "program_stats.json")]
                )
                run(
                    [sys.executable, str(base_dir / "controller_stats.py")]
                    + stats_common
                    + ["--out", str(run_dir / "controller_stats.json")]
                )
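                # actuator_stats.py, pv_stats.py, and aux_stats.py below reuse
                # the same shared arguments on the post-processed samples.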
                run(
                    [sys.executable, str(base_dir / "actuator_stats.py")]
                    + stats_common
                    + ["--out", str(run_dir / "actuator_stats.json")]
                )
                run(
                    [sys.executable, str(base_dir / "pv_stats.py")]
                    + stats_common
                    + ["--out", str(run_dir / "pv_stats.json")]
                )
                run(
                    [sys.executable, str(base_dir / "aux_stats.py")]
                    + stats_common
                    + ["--out", str(run_dir / "aux_stats.json")]
                )
    if batch_mode:
        compute_summary(history_out, summary_out)


if __name__ == "__main__":
    main()
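
# Example invocations (script file name hypothetical):
#
#   python run_pipeline.py --device cpu
#   python run_pipeline.py --configs "configs/*.json" --seeds 1,2,3 \
#       --runs-root results/sweep --name-prefix sweep_
#
# Any of --configs, --seeds, --repeat > 1, --runs-root, or the --benchmark-*
# flags switches the runner into batch mode: artifacts land in per-run
# directories under the runs root, and the history CSV is aggregated into the
# summary CSV at the end.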