Files
mask-ddpm/example/run_submission_resume.py
2026-04-18 19:01:25 +08:00

400 lines
14 KiB
Python

#!/usr/bin/env python3
"""One-command full pipeline runner with safe resume and stage skipping."""
from __future__ import annotations
import argparse
import json
import subprocess
import sys
from pathlib import Path
from typing import Dict, List
from platform_utils import is_windows, safe_path
def run(cmd: List[str]) -> None:
    """Echo *cmd*, normalize every argument via ``safe_path``, and execute it.

    Args:
        cmd: Command argv list (program first, arguments after).

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero
            (``check=True``).
    """
    # Print the command as given, before safe_path normalization, so the log
    # mirrors what the caller asked for.
    print("running:", " ".join(cmd))
    # The original branched on is_windows() but both branches were identical
    # in effect: shell=False is already the subprocess.run default, so a
    # single call covers every platform.
    subprocess.run([safe_path(arg) for arg in cmd], check=True)
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command-line options for the pipeline runner.

    Args:
        argv: Optional explicit argument list. ``None`` (the default) keeps
            the original behavior of reading ``sys.argv[1:]``; passing a list
            makes the function unit-testable.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    base_dir = Path(__file__).resolve().parent
    parser = argparse.ArgumentParser(description="Run prepare -> train -> export -> eval with resume-aware staging.")
    # Paths default relative to this script so the runner works from any cwd.
    parser.add_argument("--config", default=str(base_dir / "config_submission_full.json"))
    parser.add_argument("--device", default="auto")
    parser.add_argument("--run-dir", default=str(base_dir / "results" / "submission_full"))
    parser.add_argument("--reference", default="")
    parser.add_argument("--no-resume", action="store_true", help="Do not auto-skip completed stages or resume from ckpt.")
    # Per-stage opt-outs; each simply drops that stage from the schedule.
    parser.add_argument("--skip-prepare", action="store_true")
    parser.add_argument("--skip-train", action="store_true")
    parser.add_argument("--skip-export", action="store_true")
    parser.add_argument("--skip-eval", action="store_true")
    parser.add_argument("--skip-comprehensive-eval", action="store_true")
    parser.add_argument("--skip-postprocess", action="store_true")
    parser.add_argument("--skip-post-eval", action="store_true")
    parser.add_argument("--skip-diagnostics", action="store_true")
    return parser.parse_args(argv)
def load_state(path: Path) -> Dict[str, str]:
    """Load the pipeline-state JSON at *path*.

    Returns:
        The parsed stage->status mapping, or ``{}`` when the file is missing,
        unreadable, or not valid JSON (treated as "start fresh").
    """
    if not path.exists():
        return {}
    try:
        return json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # Narrowed from a bare `except Exception`: only read failures and
        # JSON/decoding errors mean "no usable state"; json.JSONDecodeError
        # and UnicodeDecodeError are ValueError subclasses. Any other
        # exception is a real bug and should propagate.
        return {}
def save_state(path: Path, state: Dict[str, str]) -> None:
    """Write *state* to *path* as pretty-printed, key-sorted JSON.

    Creates any missing parent directories first so the write cannot fail
    on a fresh run directory.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    payload = json.dumps(state, indent=2, sort_keys=True)
    path.write_text(payload, encoding="utf-8")
def stage_complete(state: Dict[str, str], stage: str, outputs: List[Path], resume: bool) -> bool:
    """Decide whether *stage* may be skipped under resume semantics.

    A stage counts as complete when resume is enabled and either every
    declared output file already exists (non-empty output list only) or the
    state file records the stage as "done".
    """
    if not resume:
        return False
    outputs_present = bool(outputs) and all(out.exists() for out in outputs)
    return outputs_present or state.get(stage) == "done"
def main() -> None:
    """Drive the full pipeline with resume-aware staging.

    Builds an ordered list of (stage_name, expected_outputs, command) triples
    — prepare, train, export, eval, comprehensive eval, postprocess,
    post-eval, comprehensive post-eval, diagnostics — honoring the per-stage
    ``--skip-*`` flags, then executes each stage as a subprocess, recording
    progress in ``pipeline_state.json`` and every launched command in
    ``run_commands.txt``.
    """
    args = parse_args()
    base_dir = Path(__file__).resolve().parent
    # --config and --run-dir are resolved relative to this script's directory
    # when given as relative paths.
    config_path = Path(args.config)
    if not config_path.is_absolute():
        config_path = (base_dir / config_path).resolve()
    run_dir = Path(args.run_dir)
    if not run_dir.is_absolute():
        run_dir = (base_dir / run_dir).resolve()
    run_dir.mkdir(parents=True, exist_ok=True)
    cfg = json.loads(config_path.read_text(encoding="utf-8"))
    cfg_base = config_path.parent

    def abs_cfg_like(value: str) -> str:
        # Absolutize a config-relative path. Glob patterns (any of * ? [)
        # are joined but NOT resolve()d so the wildcard text survives intact.
        p = Path(value)
        if p.is_absolute():
            return str(p)
        if any(ch in value for ch in ["*", "?", "["]):
            return str(cfg_base / p)
        return str((cfg_base / p).resolve())

    # Reference data precedence: CLI flag, then config data_glob, then
    # config data_path; empty string means "no reference".
    ref = args.reference or cfg.get("data_glob") or cfg.get("data_path") or ""
    if ref:
        ref = abs_cfg_like(str(ref))
    # Sampling/eval knobs, each with a config fallback default.
    timesteps = int(cfg.get("timesteps", 200))
    seq_len = int(cfg.get("sample_seq_len", cfg.get("seq_len", 64)))
    batch_size = int(cfg.get("sample_batch_size", cfg.get("batch_size", 2)))
    clip_k = float(cfg.get("clip_k", 5.0))
    split_path = abs_cfg_like(str(cfg.get("split_path", "./feature_split.json")))
    stats_path = abs_cfg_like(str(cfg.get("stats_path", "./results/cont_stats.json")))
    vocab_path = abs_cfg_like(str(cfg.get("vocab_path", "./results/disc_vocab.json")))
    data_path = abs_cfg_like(str(cfg.get("data_path", ""))) if cfg.get("data_path") else ""
    data_glob = abs_cfg_like(str(cfg.get("data_glob", ""))) if cfg.get("data_glob") else ""
    state_path = run_dir / "pipeline_state.json"
    state = load_state(state_path)
    resume = not args.no_resume
    # Config snapshot presumably written by the train stage.
    # NOTE(review): cfg_for_steps.exists() is evaluated below while stage
    # commands are being BUILT — i.e. before train has run — so on a fresh
    # run every later stage falls back to config_path. Confirm intended.
    cfg_for_steps = run_dir / "config_used.json"
    stage_defs: List[tuple] = []
    if not args.skip_prepare:
        stage_defs.append(
            (
                "prepare",
                [Path(stats_path), Path(vocab_path)],
                [sys.executable, str(base_dir / "prepare_data.py"), "--config", str(config_path)],
            )
        )
    if not args.skip_train:
        train_cmd = [
            sys.executable,
            str(base_dir / "train_resume.py"),
            "--config",
            str(config_path),
            "--device",
            args.device,
            "--out-dir",
            str(run_dir),
            "--seed",
            str(int(cfg.get("seed", 1337))),
        ]
        # Only ask the trainer to resume from a checkpoint when resume mode
        # is on.
        if resume:
            train_cmd.append("--resume")
        stage_defs.append(("train", [run_dir / "model.pt"], train_cmd))
    if not args.skip_export:
        stage_defs.append(
            (
                "export",
                [run_dir / "generated.csv"],
                [
                    sys.executable,
                    str(base_dir / "export_samples_resume.py"),
                    "--include-time",
                    "--device",
                    args.device,
                    "--config",
                    str(cfg_for_steps if cfg_for_steps.exists() else config_path),
                    "--data-path",
                    str(data_path),
                    "--data-glob",
                    str(data_glob),
                    "--split-path",
                    str(split_path),
                    "--stats-path",
                    str(stats_path),
                    "--vocab-path",
                    str(vocab_path),
                    "--model-path",
                    str(run_dir / "model.pt"),
                    "--out",
                    str(run_dir / "generated.csv"),
                    "--timesteps",
                    str(timesteps),
                    "--seq-len",
                    str(seq_len),
                    "--batch-size",
                    str(batch_size),
                    "--clip-k",
                    str(clip_k),
                    "--use-ema",
                ],
            )
        )
    if not args.skip_eval:
        eval_cmd = [
            sys.executable,
            str(base_dir / "evaluate_generated.py"),
            "--generated",
            str(run_dir / "generated.csv"),
            "--split",
            str(split_path),
            "--stats",
            str(stats_path),
            "--vocab",
            str(vocab_path),
            "--out",
            str(run_dir / "eval.json"),
        ]
        # --reference is optional for the basic evaluator.
        if ref:
            eval_cmd += ["--reference", str(ref)]
        stage_defs.append(("eval", [run_dir / "eval.json"], eval_cmd))
    if not args.skip_comprehensive_eval:
        stage_defs.append(
            (
                "comprehensive_eval",
                [run_dir / "comprehensive_eval.json"],
                [
                    sys.executable,
                    str(base_dir / "evaluate_comprehensive.py"),
                    "--generated",
                    str(run_dir / "generated.csv"),
                    # NOTE(review): config_path is passed as --reference here
                    # (not the resolved `ref` data path used by the other
                    # evaluators) — verify this is intentional.
                    "--reference",
                    str(config_path),
                    "--config",
                    str(cfg_for_steps if cfg_for_steps.exists() else config_path),
                    "--split",
                    str(split_path),
                    "--stats",
                    str(stats_path),
                    "--vocab",
                    str(vocab_path),
                    "--out",
                    str(run_dir / "comprehensive_eval.json"),
                    "--device",
                    args.device,
                ],
            )
        )
    if not args.skip_postprocess:
        post_cmd = [
            sys.executable,
            str(base_dir / "postprocess_types.py"),
            "--generated",
            str(run_dir / "generated.csv"),
            "--config",
            str(cfg_for_steps if cfg_for_steps.exists() else config_path),
            "--out",
            str(run_dir / "generated_post.csv"),
            "--seed",
            str(int(cfg.get("seed", 1337))),
        ]
        if ref:
            post_cmd += ["--reference", str(ref)]
        stage_defs.append(("postprocess", [run_dir / "generated_post.csv"], post_cmd))
    if not args.skip_post_eval:
        # Same evaluator as "eval", pointed at the postprocessed CSV.
        post_eval_cmd = [
            sys.executable,
            str(base_dir / "evaluate_generated.py"),
            "--generated",
            str(run_dir / "generated_post.csv"),
            "--split",
            str(split_path),
            "--stats",
            str(stats_path),
            "--vocab",
            str(vocab_path),
            "--out",
            str(run_dir / "eval_post.json"),
        ]
        if ref:
            post_eval_cmd += ["--reference", str(ref)]
        stage_defs.append(("post_eval", [run_dir / "eval_post.json"], post_eval_cmd))
    # NOTE: the post-processed comprehensive eval shares the
    # --skip-comprehensive-eval flag with the pre-processed one.
    if not args.skip_comprehensive_eval:
        stage_defs.append(
            (
                "comprehensive_post_eval",
                [run_dir / "comprehensive_eval_post.json"],
                [
                    sys.executable,
                    str(base_dir / "evaluate_comprehensive.py"),
                    "--generated",
                    str(run_dir / "generated_post.csv"),
                    "--reference",
                    str(config_path),
                    "--config",
                    str(cfg_for_steps if cfg_for_steps.exists() else config_path),
                    "--split",
                    str(split_path),
                    "--stats",
                    str(stats_path),
                    "--vocab",
                    str(vocab_path),
                    "--out",
                    str(run_dir / "comprehensive_eval_post.json"),
                    "--device",
                    args.device,
                ],
            )
        )
    if not args.skip_diagnostics:
        # NOTE(review): the *_stats stages declare JSON outputs in run_dir but
        # pass no --out flag; presumably those scripts write to a default
        # location — confirm the declared output paths match.
        stage_defs.extend(
            [
                (
                    "filtered_metrics",
                    [run_dir / "filtered_metrics.json"],
                    [
                        sys.executable,
                        str(base_dir / "filtered_metrics.py"),
                        "--eval",
                        str(run_dir / "eval.json"),
                        "--out",
                        str(run_dir / "filtered_metrics.json"),
                    ],
                ),
                (
                    "ranked_ks",
                    [run_dir / "ranked_ks.csv"],
                    [
                        sys.executable,
                        str(base_dir / "ranked_ks.py"),
                        "--eval",
                        str(run_dir / "eval.json"),
                        "--out",
                        str(run_dir / "ranked_ks.csv"),
                    ],
                ),
                (
                    "program_stats",
                    [run_dir / "program_stats.json"],
                    [
                        sys.executable,
                        str(base_dir / "program_stats.py"),
                        "--generated",
                        str(run_dir / "generated.csv"),
                        "--reference",
                        str(config_path),
                        "--config",
                        str(cfg_for_steps if cfg_for_steps.exists() else config_path),
                    ],
                ),
                (
                    "controller_stats",
                    [run_dir / "controller_stats.json"],
                    [
                        sys.executable,
                        str(base_dir / "controller_stats.py"),
                        "--generated",
                        str(run_dir / "generated.csv"),
                        "--reference",
                        str(config_path),
                        "--config",
                        str(cfg_for_steps if cfg_for_steps.exists() else config_path),
                    ],
                ),
                (
                    "actuator_stats",
                    [run_dir / "actuator_stats.json"],
                    [
                        sys.executable,
                        str(base_dir / "actuator_stats.py"),
                        "--generated",
                        str(run_dir / "generated.csv"),
                        "--reference",
                        str(config_path),
                        "--config",
                        str(cfg_for_steps if cfg_for_steps.exists() else config_path),
                    ],
                ),
                (
                    "pv_stats",
                    [run_dir / "pv_stats.json"],
                    [
                        sys.executable,
                        str(base_dir / "pv_stats.py"),
                        "--generated",
                        str(run_dir / "generated.csv"),
                        "--reference",
                        str(config_path),
                        "--config",
                        str(cfg_for_steps if cfg_for_steps.exists() else config_path),
                    ],
                ),
                (
                    "aux_stats",
                    [run_dir / "aux_stats.json"],
                    [
                        sys.executable,
                        str(base_dir / "aux_stats.py"),
                        "--generated",
                        str(run_dir / "generated.csv"),
                        "--reference",
                        str(config_path),
                        "--config",
                        str(cfg_for_steps if cfg_for_steps.exists() else config_path),
                    ],
                ),
            ]
        )
    # Append-only log of every command launched across (re)runs.
    command_log = run_dir / "run_commands.txt"
    if not command_log.exists():
        command_log.write_text("", encoding="utf-8")
    for stage, outputs, cmd in stage_defs:
        if stage_complete(state, stage, outputs, resume):
            print(f"skip_stage {stage}: outputs already present")
            state[stage] = "done"
            save_state(state_path, state)
            continue
        # Record "running" before the subprocess starts so a crash leaves
        # evidence of which stage was in flight.
        state[stage] = "running"
        save_state(state_path, state)
        with command_log.open("a", encoding="utf-8") as fh:
            fh.write(stage + ": " + " ".join(cmd) + "\n")
        run(cmd)
        state[stage] = "done"
        save_state(state_path, state)
    print(f"pipeline_complete run_dir={run_dir}")
# Script entry point: run the full pipeline when executed directly.
if __name__ == "__main__":
    main()