Optimized 6 classes; KS is now down to 0.28. Dubbed version 3.0.
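Note: "KS" here presumably refers to the two-sample Kolmogorov–Smirnov distance that the evaluation step stores in results/eval.json under continuous_ks (the new script below reads that field to pick features); lower is better, so 0.28 means the generated and reference distributions now sit noticeably closer. A minimal sketch of computing such a distance for one feature, assuming scipy is available (illustrative only, not necessarily the exact code the evaluation uses):

    from scipy.stats import ks_2samp

    # toy generated vs. reference values for a single feature
    gen = [0.0, 0.5, 0.5, 1.0, 1.5]
    ref = [0.0, 0.5, 1.0, 1.0, 2.0]
    result = ks_2samp(gen, ref)
    print(result.statistic)  # KS distance in [0, 1]; smaller means closer distributions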
142
example/program_stats.py
Normal file
@@ -0,0 +1,142 @@
#!/usr/bin/env python3
"""Compute program-style stats (dwell, change count, step size) for selected features."""

import argparse
import csv
import gzip
import json
from pathlib import Path
from typing import Dict, List


def parse_args():
    base_dir = Path(__file__).resolve().parent
    parser = argparse.ArgumentParser(description="Program stats for setpoints/demands.")
    parser.add_argument("--generated", default=str(base_dir / "results" / "generated.csv"))
    parser.add_argument("--reference", default=str(base_dir / "config.json"))
    parser.add_argument("--features", default="", help="comma-separated list; empty = auto from eval")
    parser.add_argument("--config", default=str(base_dir / "config.json"))
    parser.add_argument("--out", default=str(base_dir / "results" / "program_stats.json"))
    parser.add_argument("--max-rows", type=int, default=200000)
    return parser.parse_args()


def resolve_reference_glob(ref_arg: str) -> str:
    ref_path = Path(ref_arg)
    if ref_path.suffix == ".json":
        cfg = json.loads(ref_path.read_text(encoding="utf-8"))
        data_glob = cfg.get("data_glob") or cfg.get("data_path") or ""
        if not data_glob:
            raise SystemExit("reference config has no data_glob/data_path")
        combined = ref_path.parent / data_glob
        # avoid resolve on glob patterns
        if "*" in str(combined) or "?" in str(combined):
            return str(combined)
        return str(combined.resolve())
    return str(ref_path)


def read_series(path: Path, cols: List[str], max_rows: int) -> Dict[str, List[float]]:
    vals = {c: [] for c in cols}
    opener = gzip.open if str(path).endswith(".gz") else open
    with opener(path, "rt", newline="") as fh:
        reader = csv.DictReader(fh)
        for i, row in enumerate(reader):
            for c in cols:
                try:
                    vals[c].append(float(row[c]))
                except Exception:
                    pass
            if max_rows > 0 and i + 1 >= max_rows:
                break
    return vals


def dwell_and_steps(series: List[float]):
    if not series:
        return {
            "num_changes": 0,
            "mean_dwell": None,
            "median_dwell": None,
            "mean_step": None,
            "median_step": None,
        }
    changes = 0
    dwells = []
    steps = []
    current = series[0]
    dwell = 1
    for v in series[1:]:
        if v == current:
            dwell += 1
            continue
        changes += 1
        dwells.append(dwell)
        steps.append(abs(v - current))
        current = v
        dwell = 1
    dwells.append(dwell)

    def mean(x):
        return sum(x) / len(x) if x else None

    def median(x):
        if not x:
            return None
        xs = sorted(x)
        mid = len(xs) // 2
        return xs[mid] if len(xs) % 2 == 1 else 0.5 * (xs[mid - 1] + xs[mid])

    return {
        "num_changes": changes,
        "mean_dwell": mean(dwells),
        "median_dwell": median(dwells),
        "mean_step": mean(steps),
        "median_step": median(steps),
    }


def main():
    args = parse_args()
    out_path = Path(args.out)

    eval_path = Path("results") / "eval.json"
    auto_feats = []
    if eval_path.exists():
        data = json.loads(eval_path.read_text(encoding="utf-8"))
        ks = data.get("continuous_ks", {})
        auto_feats = [k for k, v in ks.items() if v >= 0.6]

    features = [f.strip() for f in args.features.split(",") if f.strip()] or auto_feats
    if not features and Path(args.config).exists():
        cfg = json.loads(Path(args.config).read_text(encoding="utf-8"))
        features = cfg.get("type1_features", []) or []
    if not features:
        raise SystemExit("no features specified and no eval.json with ks>=0.6")

    # generated series
    gen_vals = read_series(Path(args.generated), features, args.max_rows)

    # reference series (aggregate across files)
    ref_glob = resolve_reference_glob(args.reference)
    ref_paths = sorted(Path(ref_glob).parent.glob(Path(ref_glob).name))
    if not ref_paths:
        raise SystemExit(f"no reference files matched: {ref_glob}")
    real_vals = {c: [] for c in features}
    for p in ref_paths:
        vals = read_series(p, features, args.max_rows)
        for c in features:
            real_vals[c].extend(vals[c])

    out = {"features": features, "generated": {}, "reference": {}}
    for c in features:
        out["generated"][c] = dwell_and_steps(gen_vals[c])
        out["reference"][c] = dwell_and_steps(real_vals[c])

    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(out, indent=2), encoding="utf-8")
    print("wrote", out_path)


if __name__ == "__main__":
    main()
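As a quick sanity check on the dwell/step logic in dwell_and_steps, a toy setpoint trace (illustrative values only) works out like this:

    # 50.0 held for 3 steps, 55.0 for 2 steps, then 52.5 once
    trace = [50.0, 50.0, 50.0, 55.0, 55.0, 52.5]
    print(dwell_and_steps(trace))
    # num_changes=2, dwells [3, 2, 1] -> mean 2.0, median 2
    # steps [5.0, 2.5] -> mean 3.75, median 3.75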