mask-ddpm/example/run_all.py

#!/usr/bin/env python3
"""One-command pipeline runner with config-driven paths."""
import argparse
import csv
import json
import math
import subprocess
import sys
from pathlib import Path
from platform_utils import safe_path


def run(cmd):
    """Log and execute one pipeline step, failing fast on a nonzero exit code."""
    # Normalize arguments for the host platform first, so the printed command
    # matches what is actually executed.
    cmd = [safe_path(arg) for arg in cmd]
    print("running:", " ".join(cmd))
    # shell=False is subprocess.run's default on every platform; it is spelled
    # out here to make the no-shell intent explicit.
    subprocess.run(cmd, check=True, shell=False)


def parse_args():
parser = argparse.ArgumentParser(description="Run full pipeline end-to-end.")
base_dir = Path(__file__).resolve().parent
parser.add_argument("--config", default=str(base_dir / "config.json"))
parser.add_argument("--configs", default="", help="Comma-separated configs or globs for batch runs")
parser.add_argument("--seeds", default="", help="Comma-separated seeds for batch runs")
parser.add_argument("--repeat", type=int, default=1, help="Repeat each config with different seeds")
parser.add_argument("--runs-root", default="", help="Root directory for per-run artifacts (batch)")
parser.add_argument("--benchmark-history", default="", help="CSV path for batch history output")
parser.add_argument("--benchmark-summary", default="", help="CSV path for batch summary output")
parser.add_argument("--name-prefix", default="", help="Prefix for batch run directory names")
parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
parser.add_argument("--reference", default="", help="override reference glob (train*.csv.gz)")
parser.add_argument("--skip-prepare", action="store_true")
parser.add_argument("--skip-train", action="store_true")
parser.add_argument("--skip-export", action="store_true")
parser.add_argument("--skip-eval", action="store_true")
parser.add_argument("--skip-postprocess", action="store_true")
parser.add_argument("--skip-post-eval", action="store_true")
parser.add_argument("--skip-diagnostics", action="store_true")
return parser.parse_args()

def resolve_config_path(base_dir: Path, cfg_arg: str) -> Path:
    """Resolve a config argument against the example dir and the repo root."""
    p = Path(cfg_arg)
if p.is_absolute():
if p.exists():
return p.resolve()
raise SystemExit(f"config not found: {p}")
repo_dir = base_dir.parent
candidates = [p, base_dir / p, repo_dir / p]
if p.parts and p.parts[0] == "example":
trimmed = Path(*p.parts[1:]) if len(p.parts) > 1 else Path()
if str(trimmed):
candidates.extend([base_dir / trimmed, repo_dir / trimmed])
for c in candidates:
if c.exists():
return c.resolve()
tried = "\n".join(str(c) for c in candidates)
raise SystemExit(f"config not found: {cfg_arg}\ntried:\n{tried}")

def resolve_like(base: Path, value: str) -> str:
    """Resolve a config-relative path; glob patterns are anchored, not resolved."""
    if not value:
        return ""
    p = Path(value)
    if p.is_absolute():
        return str(p)
    s = str(value)
    if any(ch in s for ch in ["*", "?", "["]):
        # A glob pattern is not an existing path, so Path.resolve() could mangle
        # it; just anchor the pattern at the config directory.
        return str(base / p)
    return str((base / p).resolve())
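
# Example for resolve_like (hypothetical layout): given base=/repo/example,
# "train*.csv.gz" becomes "/repo/example/train*.csv.gz" with the wildcard kept,
# while a plain "train.csv.gz" is fully resolved.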
def expand_config_args(base_dir: Path, arg: str):
if not arg:
return []
repo_dir = base_dir.parent
tokens = [t.strip() for t in arg.split(",") if t.strip()]
out = []
for t in tokens:
        if any(ch in t for ch in ["*", "?", "["]):
            p = Path(t)
            if p.is_absolute():
                matches = sorted(p.parent.glob(p.name))
                # Fail loudly for absolute globs too, matching the relative case.
                if not matches:
                    raise SystemExit(f"no configs matched glob: {t}")
                out.extend(matches)
            else:
                candidates = [base_dir / p, repo_dir / p, p]
                matched = False
                for c in candidates:
                    matches = sorted(c.parent.glob(c.name))
                    if matches:
                        out.extend(matches)
                        matched = True
                        break
                if not matched:
                    raise SystemExit(f"no configs matched glob: {t}")
else:
out.append(resolve_config_path(base_dir, t))
    # De-duplicate while preserving first-seen order.
    seen = set()
uniq = []
for p in out:
rp = str(Path(p).resolve())
if rp in seen:
continue
seen.add(rp)
uniq.append(Path(rp))
return uniq

def parse_seeds(arg: str):
    """Parse a comma-separated seed list into a list of ints."""
if not arg:
return []
out = []
for part in [p.strip() for p in arg.split(",") if p.strip()]:
out.append(int(part))
return out
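
# e.g. parse_seeds("1, 2,3") == [1, 2, 3]; parse_seeds("") == [].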

def compute_summary(history_path: Path, out_path: Path):
    """Aggregate per-run metrics from the history CSV into per-config rows."""
    if not history_path.exists():
        return
rows = []
with history_path.open("r", encoding="utf-8", newline="") as f:
reader = csv.DictReader(f)
for r in reader:
rows.append(r)
if not rows:
return
def to_float(v):
try:
return float(v)
except Exception:
return None
grouped = {}
for r in rows:
cfg = r.get("config", "") or ""
grouped.setdefault(cfg, []).append(r)
def mean_std(vals):
xs = [x for x in vals if x is not None]
if not xs:
return None, None
mu = sum(xs) / len(xs)
if len(xs) <= 1:
return mu, 0.0
var = sum((x - mu) ** 2 for x in xs) / (len(xs) - 1)
return mu, math.sqrt(var)
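    # e.g. mean_std([1.0, 2.0, 3.0]) == (2.0, 1.0): sample (n-1) variance.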
out_path.parent.mkdir(parents=True, exist_ok=True)
with out_path.open("w", encoding="utf-8", newline="") as f:
fieldnames = [
"config",
"n_runs",
"avg_ks_mean",
"avg_ks_std",
"avg_jsd_mean",
"avg_jsd_std",
"avg_lag1_diff_mean",
"avg_lag1_diff_std",
"best_run_name",
"best_avg_ks",
]
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
for cfg, rs in sorted(grouped.items(), key=lambda kv: kv[0]):
kss = [to_float(x.get("avg_ks")) for x in rs]
jsds = [to_float(x.get("avg_jsd")) for x in rs]
lags = [to_float(x.get("avg_lag1_diff")) for x in rs]
ks_mu, ks_sd = mean_std(kss)
jsd_mu, jsd_sd = mean_std(jsds)
lag_mu, lag_sd = mean_std(lags)
best = None
for r in rs:
ks = to_float(r.get("avg_ks"))
if ks is None:
continue
if best is None or ks < best[0]:
best = (ks, r.get("run_name", ""))
writer.writerow(
{
"config": cfg,
"n_runs": len(rs),
"avg_ks_mean": ks_mu,
"avg_ks_std": ks_sd,
"avg_jsd_mean": jsd_mu,
"avg_jsd_std": jsd_sd,
"avg_lag1_diff_mean": lag_mu,
"avg_lag1_diff_std": lag_sd,
"best_run_name": "" if best is None else best[1],
"best_avg_ks": None if best is None else best[0],
}
)

def main():
args = parse_args()
base_dir = Path(__file__).resolve().parent
seed_list = parse_seeds(args.seeds)
    config_paths = (
        expand_config_args(base_dir, args.configs)
        if args.configs
        else [resolve_config_path(base_dir, args.config)]
    )
    # Any batch-oriented flag switches the runner into batch mode: per-run
    # output directories plus benchmark history/summary CSVs.
    batch_mode = bool(
        args.configs
        or args.seeds
        or (args.repeat and args.repeat > 1)
        or args.runs_root
        or args.benchmark_history
        or args.benchmark_summary
    )
if not args.skip_prepare:
run([sys.executable, str(base_dir / "prepare_data.py")])
runs_root = Path(args.runs_root) if args.runs_root else (base_dir / "results" / "runs")
history_out = Path(args.benchmark_history) if args.benchmark_history else (base_dir / "results" / "benchmark_history.csv")
summary_out = Path(args.benchmark_summary) if args.benchmark_summary else (base_dir / "results" / "benchmark_summary.csv")
for config_path in config_paths:
cfg_base = config_path.parent
with open(config_path, "r", encoding="utf-8") as f:
cfg = json.load(f)
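        # Sampling-time knobs fall back to their training-time counterparts
        # (sample_seq_len -> seq_len, sample_batch_size -> batch_size).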
timesteps = cfg.get("timesteps", 200)
seq_len = cfg.get("sample_seq_len", cfg.get("seq_len", 64))
batch_size = cfg.get("sample_batch_size", cfg.get("batch_size", 2))
clip_k = cfg.get("clip_k", 5.0)
seeds = seed_list
if not seeds:
base_seed = int(cfg.get("seed", 1337))
if args.repeat and args.repeat > 1:
seeds = [base_seed + i for i in range(int(args.repeat))]
else:
seeds = [base_seed]
for seed in seeds:
            run_dir = (
                (runs_root / f"{args.name_prefix}{config_path.stem}__seed{seed}")
                if batch_mode
                else base_dir / "results"
            )
run_dir.mkdir(parents=True, exist_ok=True)
data_path = resolve_like(cfg_base, str(cfg.get("data_path", "")))
data_glob = resolve_like(cfg_base, str(cfg.get("data_glob", "")))
split_path = resolve_like(cfg_base, str(cfg.get("split_path", ""))) or str(base_dir / "feature_split.json")
stats_path = resolve_like(cfg_base, str(cfg.get("stats_path", ""))) or str(base_dir / "results" / "cont_stats.json")
vocab_path = resolve_like(cfg_base, str(cfg.get("vocab_path", ""))) or str(base_dir / "results" / "disc_vocab.json")
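            # The reference dataset (by default the config's data_glob/data_path,
            # i.e. the real training data) feeds the distribution comparisons in
            # the evaluation and diagnostics steps below.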
ref = args.reference or cfg.get("data_glob") or cfg.get("data_path") or ""
ref = resolve_like(cfg_base, str(ref)) if ref else ""
if not args.skip_train:
run(
[
sys.executable,
str(base_dir / "train.py"),
"--config",
str(config_path),
"--device",
args.device,
"--out-dir",
str(run_dir),
"--seed",
str(seed),
]
)
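            # Prefer the config snapshot recorded in the run directory during
            # training, so later stages use exactly the settings that produced
            # model.pt.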
config_used = run_dir / "config_used.json"
cfg_for_steps = config_used if config_used.exists() else config_path
if not args.skip_export:
run(
[
sys.executable,
str(base_dir / "export_samples.py"),
"--include-time",
"--device",
args.device,
"--config",
str(cfg_for_steps),
"--data-path",
str(data_path),
"--data-glob",
str(data_glob),
"--split-path",
str(split_path),
"--stats-path",
str(stats_path),
"--vocab-path",
str(vocab_path),
"--model-path",
str(run_dir / "model.pt"),
"--out",
str(run_dir / "generated.csv"),
"--timesteps",
str(timesteps),
"--seq-len",
str(seq_len),
"--batch-size",
str(batch_size),
"--clip-k",
str(clip_k),
"--use-ema",
]
)
if not args.skip_eval:
cmd = [
sys.executable,
str(base_dir / "evaluate_generated.py"),
"--generated",
str(run_dir / "generated.csv"),
"--split",
str(split_path),
"--stats",
str(stats_path),
"--vocab",
str(vocab_path),
"--out",
str(run_dir / "eval.json"),
]
if ref:
cmd += ["--reference", str(ref)]
run(cmd)
if batch_mode:
run(
[
sys.executable,
str(base_dir / "summary_metrics.py"),
"--eval",
str(run_dir / "eval.json"),
"--history",
str(history_out),
"--run-name",
run_dir.name,
"--config",
str(config_path),
"--seed",
str(seed),
]
)
else:
run(
[
sys.executable,
str(base_dir / "summary_metrics.py"),
"--eval",
str(run_dir / "eval.json"),
"--history",
str(base_dir / "results" / "metrics_history.csv"),
]
)
if not args.skip_postprocess:
cmd = [
sys.executable,
str(base_dir / "postprocess_types.py"),
"--generated",
str(run_dir / "generated.csv"),
"--config",
str(cfg_for_steps),
"--out",
str(run_dir / "generated_post.csv"),
"--seed",
str(seed),
]
if ref:
cmd += ["--reference", str(ref)]
run(cmd)
if not args.skip_post_eval:
cmd = [
sys.executable,
str(base_dir / "evaluate_generated.py"),
"--generated",
str(run_dir / "generated_post.csv"),
"--split",
str(split_path),
"--stats",
str(stats_path),
"--vocab",
str(vocab_path),
"--out",
str(run_dir / "eval_post.json"),
]
if ref:
cmd += ["--reference", str(ref)]
run(cmd)
if not args.skip_diagnostics:
if ref:
run(
[
sys.executable,
str(base_dir / "diagnose_ks.py"),
"--generated",
str(run_dir / "generated_post.csv"),
"--reference",
str(ref),
]
)
run(
[
sys.executable,
str(base_dir / "filtered_metrics.py"),
"--eval",
str(run_dir / "eval_post.json"),
"--out",
str(run_dir / "filtered_metrics.json"),
]
)
run(
[
sys.executable,
str(base_dir / "ranked_ks.py"),
"--eval",
str(run_dir / "eval_post.json"),
"--out",
str(run_dir / "ranked_ks.csv"),
]
)
                # The per-subsystem stats scripts share one CLI; run each against
                # the post-processed samples. --reference is the resolved data
                # reference, consistent with the other pipeline stages, and is
                # passed only when one is available.
                for stats_script in (
                    "program_stats.py",
                    "controller_stats.py",
                    "actuator_stats.py",
                    "pv_stats.py",
                    "aux_stats.py",
                ):
                    cmd = [
                        sys.executable,
                        str(base_dir / stats_script),
                        "--generated",
                        str(run_dir / "generated_post.csv"),
                        "--config",
                        str(cfg_for_steps),
                        "--out",
                        str(run_dir / (Path(stats_script).stem + ".json")),
                    ]
                    if ref:
                        cmd += ["--reference", str(ref)]
                    run(cmd)
if batch_mode:
compute_summary(history_out, summary_out)

if __name__ == "__main__":
    main()