update
@@ -73,6 +73,13 @@ def parse_args():
    return parser.parse_args()


def load_torch_state(path: str, device: str):
    try:
        return torch.load(path, map_location=device, weights_only=True)
    except TypeError:
        return torch.load(path, map_location=device)


# Use the resolve_device helper from platform_utils
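Note: the try/except above exists because torch.load only grew the weights_only keyword in newer PyTorch releases (1.13+); older versions reject it as an unexpected keyword argument, which surfaces as the TypeError the fallback catches. A minimal usage sketch with a stand-in module (the checkpoint path is hypothetical):

    import torch
    import torch.nn as nn

    net = nn.Linear(4, 2)                    # stand-in for the diffusion model
    torch.save(net.state_dict(), "demo.pt")  # hypothetical checkpoint
    net.load_state_dict(load_torch_state("demo.pt", "cpu"))
    net.eval()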
@@ -193,9 +200,9 @@ def main():
    ).to(device)
    if args.use_ema and os.path.exists(args.model_path.replace("model.pt", "model_ema.pt")):
        ema_path = args.model_path.replace("model.pt", "model_ema.pt")
        model.load_state_dict(torch.load(ema_path, map_location=device, weights_only=True))
        model.load_state_dict(load_torch_state(ema_path, device))
    else:
        model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True))
        model.load_state_dict(load_torch_state(args.model_path, device))
    model.eval()

    temporal_model = None

@@ -221,7 +228,7 @@ def main():
        temporal_path = Path(args.model_path).with_name("temporal.pt")
        if not temporal_path.exists():
            raise SystemExit(f"missing temporal model file: {temporal_path}")
        temporal_model.load_state_dict(torch.load(temporal_path, map_location=device, weights_only=True))
        temporal_model.load_state_dict(load_torch_state(str(temporal_path), device))
        temporal_model.eval()

    betas = cosine_beta_schedule(args.timesteps).to(device)
@@ -2,7 +2,9 @@
"""One-command pipeline runner with config-driven paths."""

import argparse
import csv
import json
import math
import subprocess
import sys
from pathlib import Path
@@ -23,6 +25,13 @@ def parse_args():
    parser = argparse.ArgumentParser(description="Run full pipeline end-to-end.")
    base_dir = Path(__file__).resolve().parent
    parser.add_argument("--config", default=str(base_dir / "config.json"))
    parser.add_argument("--configs", default="", help="Comma-separated configs or globs for batch runs")
    parser.add_argument("--seeds", default="", help="Comma-separated seeds for batch runs")
    parser.add_argument("--repeat", type=int, default=1, help="Repeat each config with different seeds")
    parser.add_argument("--runs-root", default="", help="Root directory for per-run artifacts (batch)")
    parser.add_argument("--benchmark-history", default="", help="CSV path for batch history output")
    parser.add_argument("--benchmark-summary", default="", help="CSV path for batch summary output")
    parser.add_argument("--name-prefix", default="", help="Prefix for batch run directory names")
    parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
    parser.add_argument("--reference", default="", help="override reference glob (train*.csv.gz)")
    parser.add_argument("--skip-prepare", action="store_true")
@@ -57,86 +66,433 @@ def resolve_config_path(base_dir: Path, cfg_arg: str) -> Path:
    raise SystemExit(f"config not found: {cfg_arg}\ntried:\n{tried}")


def resolve_like(base: Path, value: str) -> str:
    if not value:
        return ""
    p = Path(value)
    if p.is_absolute():
        return str(p)
    s = str(value)
    if any(ch in s for ch in ["*", "?", "["]):
        return str(base / p)
    return str((base / p).resolve())
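Note: resolve_like anchors relative values at the config's directory, passes absolute paths through, and skips Path.resolve() for glob-bearing values, presumably so the pattern is joined verbatim rather than normalized before later glob expansion. A behavior sketch (paths hypothetical):

    from pathlib import Path

    base = Path("/repo/exp")
    resolve_like(base, "data/train*.csv.gz")  # "/repo/exp/data/train*.csv.gz", glob kept as-is
    resolve_like(base, "config.json")         # fully resolved absolute path
    resolve_like(base, "/abs/stats.json")     # absolute input returned unchanged
    resolve_like(base, "")                    # ""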
def expand_config_args(base_dir: Path, arg: str):
    if not arg:
        return []
    repo_dir = base_dir.parent
    tokens = [t.strip() for t in arg.split(",") if t.strip()]
    out = []
    for t in tokens:
        if any(ch in t for ch in ["*", "?", "["]):
            p = Path(t)
            if p.is_absolute():
                base = p.parent
                pat = p.name
                out.extend(sorted(base.glob(pat)))
            else:
                candidates = [base_dir / p, repo_dir / p, p]
                matched = False
                for c in candidates:
                    base = c.parent
                    pat = c.name
                    matches = sorted(base.glob(pat))
                    if matches:
                        out.extend(matches)
                        matched = True
                        break
                if not matched:
                    raise SystemExit(f"no configs matched glob: {t}")
        else:
            out.append(resolve_config_path(base_dir, t))
    seen = set()
    uniq = []
    for p in out:
        rp = str(Path(p).resolve())
        if rp in seen:
            continue
        seen.add(rp)
        uniq.append(Path(rp))
    return uniq
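Note: expand_config_args accepts a comma-separated mix of literal config paths and glob patterns, tries the script directory, then the repo root, then the current directory for relative globs, and de-duplicates by resolved path while preserving order. A hypothetical call:

    # With /repo/exp/configs/a.json and /repo/exp/configs/b.json on disk:
    expand_config_args(Path("/repo/exp"), "configs/*.json,baseline.json")
    # -> [Path("/repo/exp/configs/a.json"), Path("/repo/exp/configs/b.json"),
    #     Path(".../baseline.json")]   # last entry via resolve_config_path's own search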
def parse_seeds(arg: str):
    if not arg:
        return []
    out = []
    for part in [p.strip() for p in arg.split(",") if p.strip()]:
        out.append(int(part))
    return out
def compute_summary(history_path: Path, out_path: Path):
    if not history_path.exists():
        return
    rows = []
    with history_path.open("r", encoding="utf-8", newline="") as f:
        reader = csv.DictReader(f)
        for r in reader:
            rows.append(r)
    if not rows:
        return

    def to_float(v):
        try:
            return float(v)
        except Exception:
            return None

    grouped = {}
    for r in rows:
        cfg = r.get("config", "") or ""
        grouped.setdefault(cfg, []).append(r)

    def mean_std(vals):
        xs = [x for x in vals if x is not None]
        if not xs:
            return None, None
        mu = sum(xs) / len(xs)
        if len(xs) <= 1:
            return mu, 0.0
        var = sum((x - mu) ** 2 for x in xs) / (len(xs) - 1)
        return mu, math.sqrt(var)

    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8", newline="") as f:
        fieldnames = [
            "config",
            "n_runs",
            "avg_ks_mean",
            "avg_ks_std",
            "avg_jsd_mean",
            "avg_jsd_std",
            "avg_lag1_diff_mean",
            "avg_lag1_diff_std",
            "best_run_name",
            "best_avg_ks",
        ]
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        for cfg, rs in sorted(grouped.items(), key=lambda kv: kv[0]):
            kss = [to_float(x.get("avg_ks")) for x in rs]
            jsds = [to_float(x.get("avg_jsd")) for x in rs]
            lags = [to_float(x.get("avg_lag1_diff")) for x in rs]
            ks_mu, ks_sd = mean_std(kss)
            jsd_mu, jsd_sd = mean_std(jsds)
            lag_mu, lag_sd = mean_std(lags)
            best = None
            for r in rs:
                ks = to_float(r.get("avg_ks"))
                if ks is None:
                    continue
                if best is None or ks < best[0]:
                    best = (ks, r.get("run_name", ""))
            writer.writerow(
                {
                    "config": cfg,
                    "n_runs": len(rs),
                    "avg_ks_mean": ks_mu,
                    "avg_ks_std": ks_sd,
                    "avg_jsd_mean": jsd_mu,
                    "avg_jsd_std": jsd_sd,
                    "avg_lag1_diff_mean": lag_mu,
                    "avg_lag1_diff_std": lag_sd,
                    "best_run_name": "" if best is None else best[1],
                    "best_avg_ks": None if best is None else best[0],
                }
            )
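Note: compute_summary groups history rows by config and reports, per metric, the mean and the sample standard deviation (Bessel-corrected, n-1 denominator), plus the run with the lowest avg_ks. A quick check of the mean_std semantics on made-up values:

    import math

    vals = [0.10, 0.14, 0.12]
    mu = sum(vals) / len(vals)                                # 0.12
    var = sum((x - mu) ** 2 for x in vals) / (len(vals) - 1)  # 0.0004
    print(mu, math.sqrt(var))                                 # ~0.12 0.02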
def main():
    args = parse_args()
    base_dir = Path(__file__).resolve().parent
    config_path = resolve_config_path(base_dir, args.config)
    with open(config_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)
    timesteps = cfg.get("timesteps", 200)
    seq_len = cfg.get("sample_seq_len", cfg.get("seq_len", 64))
    batch_size = cfg.get("sample_batch_size", cfg.get("batch_size", 2))
    clip_k = cfg.get("clip_k", 5.0)
    seed_list = parse_seeds(args.seeds)
    config_paths = expand_config_args(base_dir, args.configs) if args.configs else [resolve_config_path(base_dir, args.config)]
    batch_mode = bool(args.configs or args.seeds or (args.repeat and args.repeat > 1) or args.runs_root or args.benchmark_history or args.benchmark_summary)
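    # Any batch-style flag (--configs, --seeds, --repeat > 1, --runs-root, or a
    # --benchmark-* path) flips batch_mode and routes artifacts to per-run dirs.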
    if not args.skip_prepare:
        run([sys.executable, str(base_dir / "prepare_data.py")])
    if not args.skip_train:
        run([sys.executable, str(base_dir / "train.py"), "--config", str(config_path), "--device", args.device])
    if not args.skip_export:
        run(
            [
                sys.executable,
                str(base_dir / "export_samples.py"),
                "--include-time",
                "--device",
                args.device,
                "--config",
                str(config_path),
                "--timesteps",
                str(timesteps),
                "--seq-len",
                str(seq_len),
                "--batch-size",
                str(batch_size),
                "--clip-k",
                str(clip_k),
                "--use-ema",
            ]
        )
    ref = args.reference or cfg.get("data_glob") or cfg.get("data_path") or ""
    if not args.skip_eval:
        if ref:
            run([sys.executable, str(base_dir / "evaluate_generated.py"), "--reference", str(ref)])
        else:
            run([sys.executable, str(base_dir / "evaluate_generated.py")])
        run([sys.executable, str(base_dir / "summary_metrics.py")])

    if not args.skip_postprocess:
        cmd = [
            sys.executable,
            str(base_dir / "postprocess_types.py"),
            "--generated",
            str(base_dir / "results" / "generated.csv"),
            "--config",
            str(config_path),
        ]
        if ref:
            cmd += ["--reference", str(ref)]
        run(cmd)
    runs_root = Path(args.runs_root) if args.runs_root else (base_dir / "results" / "runs")
    history_out = Path(args.benchmark_history) if args.benchmark_history else (base_dir / "results" / "benchmark_history.csv")
    summary_out = Path(args.benchmark_summary) if args.benchmark_summary else (base_dir / "results" / "benchmark_summary.csv")
    if not args.skip_post_eval:
        cmd = [
            sys.executable,
            str(base_dir / "evaluate_generated.py"),
            "--generated",
            str(base_dir / "results" / "generated_post.csv"),
            "--out",
            "results/eval_post.json",
        ]
        if ref:
            cmd += ["--reference", str(ref)]
        run(cmd)
    for config_path in config_paths:
        cfg_base = config_path.parent
        with open(config_path, "r", encoding="utf-8") as f:
            cfg = json.load(f)
    if not args.skip_diagnostics:
        if ref:
            run([sys.executable, str(base_dir / "diagnose_ks.py"), "--generated", str(base_dir / "results" / "generated_post.csv"), "--reference", str(ref)])
        run([sys.executable, str(base_dir / "filtered_metrics.py"), "--eval", str(base_dir / "results" / "eval_post.json")])
        run([sys.executable, str(base_dir / "ranked_ks.py"), "--eval", str(base_dir / "results" / "eval_post.json")])
        run([sys.executable, str(base_dir / "program_stats.py"), "--config", str(config_path), "--reference", str(ref or config_path)])
        run([sys.executable, str(base_dir / "controller_stats.py"), "--config", str(config_path), "--reference", str(ref or config_path)])
        run([sys.executable, str(base_dir / "actuator_stats.py"), "--config", str(config_path), "--reference", str(ref or config_path)])
        run([sys.executable, str(base_dir / "pv_stats.py"), "--config", str(config_path), "--reference", str(ref or config_path)])
        run([sys.executable, str(base_dir / "aux_stats.py"), "--config", str(config_path), "--reference", str(ref or config_path)])
        timesteps = cfg.get("timesteps", 200)
        seq_len = cfg.get("sample_seq_len", cfg.get("seq_len", 64))
        batch_size = cfg.get("sample_batch_size", cfg.get("batch_size", 2))
        clip_k = cfg.get("clip_k", 5.0)

        seeds = seed_list
        if not seeds:
            base_seed = int(cfg.get("seed", 1337))
            if args.repeat and args.repeat > 1:
                seeds = [base_seed + i for i in range(int(args.repeat))]
            else:
                seeds = [base_seed]
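        # With --repeat N and no explicit --seeds, seeds expand to
        # base_seed .. base_seed + N - 1.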
        for seed in seeds:
            run_dir = base_dir / "results" if not batch_mode else (runs_root / f"{args.name_prefix}{config_path.stem}__seed{seed}")
            run_dir.mkdir(parents=True, exist_ok=True)
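            # Batch runs land in e.g. <runs_root>/<name-prefix><config-stem>__seed1337/.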
            data_path = resolve_like(cfg_base, str(cfg.get("data_path", "")))
            data_glob = resolve_like(cfg_base, str(cfg.get("data_glob", "")))
            split_path = resolve_like(cfg_base, str(cfg.get("split_path", ""))) or str(base_dir / "feature_split.json")
            stats_path = resolve_like(cfg_base, str(cfg.get("stats_path", ""))) or str(base_dir / "results" / "cont_stats.json")
            vocab_path = resolve_like(cfg_base, str(cfg.get("vocab_path", ""))) or str(base_dir / "results" / "disc_vocab.json")

            ref = args.reference or cfg.get("data_glob") or cfg.get("data_path") or ""
            ref = resolve_like(cfg_base, str(ref)) if ref else ""
            if not args.skip_train:
                run(
                    [
                        sys.executable,
                        str(base_dir / "train.py"),
                        "--config",
                        str(config_path),
                        "--device",
                        args.device,
                        "--out-dir",
                        str(run_dir),
                        "--seed",
                        str(seed),
                    ]
                )

            config_used = run_dir / "config_used.json"
            cfg_for_steps = config_used if config_used.exists() else config_path
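            # cfg_for_steps now points at the config snapshot train.py wrote into
            # the run dir, when it exists; export/eval/postprocess below read it.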
            if not args.skip_export:
                run(
                    [
                        sys.executable,
                        str(base_dir / "export_samples.py"),
                        "--include-time",
                        "--device",
                        args.device,
                        "--config",
                        str(cfg_for_steps),
                        "--data-path",
                        str(data_path),
                        "--data-glob",
                        str(data_glob),
                        "--split-path",
                        str(split_path),
                        "--stats-path",
                        str(stats_path),
                        "--vocab-path",
                        str(vocab_path),
                        "--model-path",
                        str(run_dir / "model.pt"),
                        "--out",
                        str(run_dir / "generated.csv"),
                        "--timesteps",
                        str(timesteps),
                        "--seq-len",
                        str(seq_len),
                        "--batch-size",
                        str(batch_size),
                        "--clip-k",
                        str(clip_k),
                        "--use-ema",
                    ]
                )
            if not args.skip_eval:
                cmd = [
                    sys.executable,
                    str(base_dir / "evaluate_generated.py"),
                    "--generated",
                    str(run_dir / "generated.csv"),
                    "--split",
                    str(split_path),
                    "--stats",
                    str(stats_path),
                    "--vocab",
                    str(vocab_path),
                    "--out",
                    str(run_dir / "eval.json"),
                ]
                if ref:
                    cmd += ["--reference", str(ref)]
                run(cmd)
                if batch_mode:
                    run(
                        [
                            sys.executable,
                            str(base_dir / "summary_metrics.py"),
                            "--eval",
                            str(run_dir / "eval.json"),
                            "--history",
                            str(history_out),
                            "--run-name",
                            run_dir.name,
                            "--config",
                            str(config_path),
                            "--seed",
                            str(seed),
                        ]
                    )
                else:
                    run(
                        [
                            sys.executable,
                            str(base_dir / "summary_metrics.py"),
                            "--eval",
                            str(run_dir / "eval.json"),
                            "--history",
                            str(base_dir / "results" / "metrics_history.csv"),
                        ]
                    )
            if not args.skip_postprocess:
                cmd = [
                    sys.executable,
                    str(base_dir / "postprocess_types.py"),
                    "--generated",
                    str(run_dir / "generated.csv"),
                    "--config",
                    str(cfg_for_steps),
                    "--out",
                    str(run_dir / "generated_post.csv"),
                    "--seed",
                    str(seed),
                ]
                if ref:
                    cmd += ["--reference", str(ref)]
                run(cmd)
            if not args.skip_post_eval:
                cmd = [
                    sys.executable,
                    str(base_dir / "evaluate_generated.py"),
                    "--generated",
                    str(run_dir / "generated_post.csv"),
                    "--split",
                    str(split_path),
                    "--stats",
                    str(stats_path),
                    "--vocab",
                    str(vocab_path),
                    "--out",
                    str(run_dir / "eval_post.json"),
                ]
                if ref:
                    cmd += ["--reference", str(ref)]
                run(cmd)
            if not args.skip_diagnostics:
                if ref:
                    run(
                        [
                            sys.executable,
                            str(base_dir / "diagnose_ks.py"),
                            "--generated",
                            str(run_dir / "generated_post.csv"),
                            "--reference",
                            str(ref),
                        ]
                    )
                run(
                    [
                        sys.executable,
                        str(base_dir / "filtered_metrics.py"),
                        "--eval",
                        str(run_dir / "eval_post.json"),
                        "--out",
                        str(run_dir / "filtered_metrics.json"),
                    ]
                )
                run(
                    [
                        sys.executable,
                        str(base_dir / "ranked_ks.py"),
                        "--eval",
                        str(run_dir / "eval_post.json"),
                        "--out",
                        str(run_dir / "ranked_ks.csv"),
                    ]
                )
                run(
                    [
                        sys.executable,
                        str(base_dir / "program_stats.py"),
                        "--generated",
                        str(run_dir / "generated_post.csv"),
                        "--config",
                        str(cfg_for_steps),
                        "--reference",
                        str(cfg_for_steps),
                        "--out",
                        str(run_dir / "program_stats.json"),
                    ]
                )
                run(
                    [
                        sys.executable,
                        str(base_dir / "controller_stats.py"),
                        "--generated",
                        str(run_dir / "generated_post.csv"),
                        "--config",
                        str(cfg_for_steps),
                        "--reference",
                        str(cfg_for_steps),
                        "--out",
                        str(run_dir / "controller_stats.json"),
                    ]
                )
                run(
                    [
                        sys.executable,
                        str(base_dir / "actuator_stats.py"),
                        "--generated",
                        str(run_dir / "generated_post.csv"),
                        "--config",
                        str(cfg_for_steps),
                        "--reference",
                        str(cfg_for_steps),
                        "--out",
                        str(run_dir / "actuator_stats.json"),
                    ]
                )
                run(
                    [
                        sys.executable,
                        str(base_dir / "pv_stats.py"),
                        "--generated",
                        str(run_dir / "generated_post.csv"),
                        "--config",
                        str(cfg_for_steps),
                        "--reference",
                        str(cfg_for_steps),
                        "--out",
                        str(run_dir / "pv_stats.json"),
                    ]
                )
                run(
                    [
                        sys.executable,
                        str(base_dir / "aux_stats.py"),
                        "--generated",
                        str(run_dir / "generated_post.csv"),
                        "--config",
                        str(cfg_for_steps),
                        "--reference",
                        str(cfg_for_steps),
                        "--out",
                        str(run_dir / "aux_stats.json"),
                    ]
                )
    if batch_mode:
        compute_summary(history_out, summary_out)


if __name__ == "__main__":
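Note: for orientation, a hypothetical batch invocation over two configs with --repeat 2 leaves artifacts laid out roughly like this (directory names follow the run_dir expression above; file names come from the per-run steps in the loop):

    results/runs/
        cfg_a__seed1337/
            config_used.json  model.pt  generated.csv  eval.json
            generated_post.csv  eval_post.json  ranked_ks.csv  ...
        cfg_a__seed1338/
        cfg_b__seed1337/
        cfg_b__seed1338/
    results/benchmark_history.csv   # one row per run, appended by summary_metrics.py
    results/benchmark_summary.csv   # per-config mean/std written by compute_summary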
@@ -29,6 +29,13 @@ BATCH_SIZE = 2
CLIP_K = 5.0


def load_torch_state(path: str, device: str):
    try:
        return torch.load(path, map_location=device, weights_only=True)
    except TypeError:
        return torch.load(path, map_location=device)


def load_vocab():
    with open(str(VOCAB_PATH), "r", encoding="utf-8") as f:
        return json.load(f)["vocab"]
@@ -110,7 +117,7 @@ def main():
        eps_scale=eps_scale,
    ).to(DEVICE)
    if MODEL_PATH.exists():
        model.load_state_dict(torch.load(str(MODEL_PATH), map_location=DEVICE, weights_only=True))
        model.load_state_dict(load_torch_state(str(MODEL_PATH), DEVICE))
    model.eval()

    temporal_model = None

@@ -136,7 +143,7 @@ def main():
        temporal_path = BASE_DIR / "results" / "temporal.pt"
        if not temporal_path.exists():
            raise SystemExit(f"missing temporal model file: {temporal_path}")
        temporal_model.load_state_dict(torch.load(str(temporal_path), map_location=DEVICE, weights_only=True))
        temporal_model.load_state_dict(load_torch_state(str(temporal_path), DEVICE))
        temporal_model.eval()

    betas = cosine_beta_schedule(timesteps).to(DEVICE)
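Note: cosine_beta_schedule itself is not shown in this diff. For reference, a common implementation of the cosine noise schedule (Nichol & Dhariwal, 2021) looks like the sketch below; the project's version may differ in the clipping bounds or the offset s:

    import math
    import torch

    def cosine_beta_schedule(timesteps: int, s: float = 0.008) -> torch.Tensor:
        # alpha-bar follows a squared-cosine curve; betas come from the ratio
        # of consecutive alpha-bar values, clipped for numerical stability.
        steps = timesteps + 1
        t = torch.linspace(0, timesteps, steps) / timesteps
        alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi / 2) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        return betas.clamp(0.0001, 0.9999)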
@@ -1,41 +1,62 @@
#!/usr/bin/env python3
"""Print average metrics from eval.json and compare with previous run."""
"""Print average metrics from eval.json and append to a history CSV."""

import argparse
import csv
import json
from datetime import datetime
from pathlib import Path
from typing import Optional


def mean(values):
    return sum(values) / len(values) if values else None
def parse_last_row(history_path: Path):
def parse_args():
    base_dir = Path(__file__).resolve().parent
    parser = argparse.ArgumentParser(description="Summarize eval.json into a history CSV.")
    parser.add_argument("--eval", dest="eval_path", default=str(base_dir / "results" / "eval.json"))
    parser.add_argument("--history", default=str(base_dir / "results" / "metrics_history.csv"))
    parser.add_argument("--run-name", default="")
    parser.add_argument("--config", default="")
    parser.add_argument("--seed", type=int, default=-1)
    return parser.parse_args()
def read_last_row(history_path: Path) -> Optional[dict]:
    if not history_path.exists():
        return None
    rows = history_path.read_text(encoding="utf-8").strip().splitlines()
    if len(rows) < 2:
    with history_path.open("r", encoding="utf-8", newline="") as f:
        reader = csv.DictReader(f)
        rows = list(reader)
    if not rows:
        return None
    for line in reversed(rows[1:]):
        parts = line.split(",")
        if len(parts) < 4:
            continue
        try:
            return {
                "avg_ks": float(parts[1]),
                "avg_jsd": float(parts[2]),
                "avg_lag1_diff": float(parts[3]),
            }
        except Exception:
            continue
    return None
    last = rows[-1]
    for key in ["avg_ks", "avg_jsd", "avg_lag1_diff"]:
        if key in last and last[key] not in [None, ""]:
            try:
                last[key] = float(last[key])
            except Exception:
                last[key] = None
    return last
def ensure_header(history_path: Path, fieldnames):
    if history_path.exists():
        return
    history_path.parent.mkdir(parents=True, exist_ok=True)
    with history_path.open("w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
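Note: with the extended fields enabled (see main() below), ensure_header writes a header like the following, and each run appends one row; the values here are made up:

    timestamp,run_name,config,seed,avg_ks,avg_jsd,avg_lag1_diff
    2025-01-01T12:00:00,cfg_a__seed1337,/repo/exp/cfg_a.json,1337,0.1234,0.0456,0.0089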
def main():
    base_dir = Path(__file__).resolve().parent
    eval_path = base_dir / "results" / "eval.json"
    args = parse_args()
    eval_path = Path(args.eval_path)
    if not eval_path.exists():
        raise SystemExit(f"missing eval.json: {eval_path}")
    history_path = Path(args.history)

    obj = json.loads(eval_path.read_text(encoding="utf-8"))
    ks = list(obj.get("continuous_ks", {}).values())
@@ -46,22 +67,48 @@ def main():
    avg_jsd = mean(jsd)
    avg_lag1 = mean(lag)

    history_path = base_dir / "results" / "metrics_history.csv"
    prev = parse_last_row(history_path)
    obj["avg_ks"] = avg_ks
    obj["avg_jsd"] = avg_jsd
    obj["avg_lag1_diff"] = avg_lag1
    eval_path.write_text(json.dumps(obj, indent=2), encoding="utf-8")

    if not history_path.exists():
        history_path.write_text("timestamp,avg_ks,avg_jsd,avg_lag1_diff\n", encoding="utf-8")
    with history_path.open("a", encoding="utf-8") as f:
        f.write(f"{datetime.utcnow().isoformat()},{avg_ks},{avg_jsd},{avg_lag1}\n")
    prev = read_last_row(history_path)

    fieldnames = ["timestamp", "avg_ks", "avg_jsd", "avg_lag1_diff"]
    extended = any([args.run_name, args.config, args.seed >= 0])
    if extended:
        fieldnames = ["timestamp", "run_name", "config", "seed", "avg_ks", "avg_jsd", "avg_lag1_diff"]
    ensure_header(history_path, fieldnames)

    row = {
        "timestamp": datetime.utcnow().isoformat(),
        "avg_ks": avg_ks,
        "avg_jsd": avg_jsd,
        "avg_lag1_diff": avg_lag1,
    }
    if extended:
        row["run_name"] = args.run_name
        row["config"] = args.config
        row["seed"] = args.seed if args.seed >= 0 else ""

    with history_path.open("a", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writerow(row)

    print("avg_ks", avg_ks)
    print("avg_jsd", avg_jsd)
    print("avg_lag1_diff", avg_lag1)

    if prev is not None:
        print("delta_avg_ks", avg_ks - prev["avg_ks"])
        print("delta_avg_jsd", avg_jsd - prev["avg_jsd"])
        print("delta_avg_lag1_diff", avg_lag1 - prev["avg_lag1_diff"])
        pks = prev.get("avg_ks")
        pjsd = prev.get("avg_jsd")
        plag = prev.get("avg_lag1_diff")
        if pks is not None:
            print("delta_avg_ks", avg_ks - pks)
        if pjsd is not None:
            print("delta_avg_jsd", avg_jsd - pjsd)
        if plag is not None:
            print("delta_avg_lag1_diff", avg_lag1 - plag)


if __name__ == "__main__":
@@ -108,6 +108,8 @@ def parse_args():
    parser = argparse.ArgumentParser(description="Train hybrid diffusion on HAI.")
    parser.add_argument("--config", default=None, help="Path to JSON config.")
    parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
    parser.add_argument("--out-dir", default=None, help="Override output directory")
    parser.add_argument("--seed", type=int, default=None, help="Override random seed")
    return parser.parse_args()
@@ -168,6 +170,14 @@ def main():
    # Prefer the device passed on the command line.
    if args.device != "auto":
        config["device"] = args.device
    if args.out_dir:
        out_dir = Path(args.out_dir)
        if not out_dir.is_absolute():
            base = Path(args.config).resolve().parent if args.config else BASE_DIR
            out_dir = resolve_path(base, out_dir)
        config["out_dir"] = str(out_dir)
    if args.seed is not None:
        config["seed"] = int(args.seed)

    set_seed(int(config["seed"]))
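Note: resolve_path and set_seed are referenced here but not shown in this diff. Plausible minimal implementations, for orientation only (the project's versions may differ):

    import random
    from pathlib import Path

    import numpy as np
    import torch


    def resolve_path(base: Path, p: Path) -> Path:
        # Anchor a relative path at `base`; leave absolute paths untouched.
        return p if p.is_absolute() else (base / p).resolve()


    def set_seed(seed: int) -> None:
        # Seed every RNG the training loop may touch, for reproducible runs.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)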