# Source: mask-ddpm/example/actuator_stats.py (128 lines, 4.2 KiB, Python)
#!/usr/bin/env python3
"""Stats for actuator/valve-like outputs (Type 3)."""
import argparse
import csv
import glob
import gzip
import json
from collections import Counter
from pathlib import Path
from typing import Dict, List
def parse_args(argv=None):
    """Parse command-line options for the actuator/valve stats script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]`` as before, so existing callers are
            unaffected; passing a list makes the parser testable.

    Returns:
        argparse.Namespace with generated/reference/features/config/out
        paths and the ``max_rows`` row cap.
    """
    base_dir = Path(__file__).resolve().parent
    parser = argparse.ArgumentParser(description="Actuator/valve stats.")
    parser.add_argument("--generated", default=str(base_dir / "results" / "generated.csv"))
    parser.add_argument("--reference", default=str(base_dir / "config.json"))
    parser.add_argument("--features", default="", help="comma-separated list")
    parser.add_argument("--config", default=str(base_dir / "config.json"))
    parser.add_argument("--out", default=str(base_dir / "results" / "actuator_stats.json"))
    parser.add_argument("--max-rows", type=int, default=200000)
    return parser.parse_args(argv)
def resolve_reference_glob(ref_arg: str) -> str:
    """Turn the ``--reference`` argument into a concrete path or glob string.

    A ``.json`` argument is read as a config file whose ``data_glob`` (or
    ``data_path``) entry, interpreted relative to the config's directory,
    names the reference data. Any other argument is returned unchanged.

    Raises:
        SystemExit: if a JSON config has neither data_glob nor data_path.
    """
    ref_path = Path(ref_arg)
    if ref_path.suffix != ".json":
        return str(ref_path)
    cfg = json.loads(ref_path.read_text(encoding="utf-8"))
    data_glob = cfg.get("data_glob") or cfg.get("data_path") or ""
    if not data_glob:
        raise SystemExit("reference config has no data_glob/data_path")
    combined = ref_path.parent / data_glob
    pattern = str(combined)
    # Wildcard patterns are kept verbatim; resolve() would mangle them.
    if any(ch in pattern for ch in "*?"):
        return pattern
    return str(combined.resolve())
def read_series(path: Path, cols: List[str], max_rows: int) -> Dict[str, List[float]]:
    """Read the requested CSV columns (plain or gzipped) as float lists.

    Cells that are missing or not parseable as float are skipped silently
    (best-effort extraction). ``max_rows`` caps how many data rows are
    consumed; a value <= 0 means no limit.
    """
    series: Dict[str, List[float]] = {name: [] for name in cols}
    open_fn = gzip.open if str(path).endswith(".gz") else open
    with open_fn(path, "rt", newline="") as handle:
        for row_count, record in enumerate(csv.DictReader(handle), start=1):
            for name in cols:
                try:
                    value = float(record[name])
                except Exception:
                    continue  # best effort: skip unparseable/missing cells
                series[name].append(value)
            if max_rows > 0 and row_count >= max_rows:
                break
    return series
def spike_stats(series: List[float]):
    """Return discreteness/dwell statistics for one series.

    The series is discretized by rounding to 2 decimal places; the stats
    describe how concentrated the values are and how long the signal dwells
    at a single level -- useful for actuator/valve-like outputs that hold
    discrete setpoints.

    Returns a dict with:
        unique_ratio: distinct rounded levels / number of samples.
        top1_mass:    fraction of samples at the most common level.
        top3_mass:    fraction of samples at the (up to) three most common
                      levels.
        median_dwell: upper median of run lengths at a constant level.
    All values are None for an empty series.
    """
    if not series:
        return {
            "unique_ratio": None,
            "top1_mass": None,
            "top3_mass": None,
            "median_dwell": None,
        }
    n = len(series)
    # Discretize by rounding so nearly-equal readings count as one level.
    rounded = [round(v, 2) for v in series]
    counts = Counter(rounded)
    unique_ratio = len(counts) / n
    top = sorted(counts.values(), reverse=True)
    top1_mass = top[0] / n
    # Bug fix: previously, a series with fewer than 3 distinct levels
    # reported top1_mass here instead of the combined mass of all its
    # levels; sum(top[:3]) is already correct for short lists.
    top3_mass = sum(top[:3]) / n
    # Run lengths ("dwell") at a constant rounded level.
    dwells = []
    current = rounded[0]
    dwell = 1
    for v in rounded[1:]:
        if v == current:
            dwell += 1
        else:
            dwells.append(dwell)
            current = v
            dwell = 1
    dwells.append(dwell)  # close the final run
    dwells.sort()
    median_dwell = dwells[len(dwells) // 2]  # upper median
    return {
        "unique_ratio": unique_ratio,
        "top1_mass": top1_mass,
        "top3_mass": top3_mass,
        "median_dwell": median_dwell,
    }
def main():
    """Compute spike/dwell stats for generated vs. reference series, write JSON.

    Reads the generated CSV and every reference file matched by the resolved
    glob, computes spike_stats per feature for both sides, and writes the
    result to ``--out``.
    """
    args = parse_args()
    # Feature list: explicit --features wins; otherwise fall back to the
    # config file's type3_features entry.
    features = [f.strip() for f in args.features.split(",") if f.strip()]
    if not features and Path(args.config).exists():
        cfg = json.loads(Path(args.config).read_text(encoding="utf-8"))
        features = cfg.get("type3_features", []) or []
    if not features:
        raise SystemExit("no features specified for actuator_stats")
    gen_vals = read_series(Path(args.generated), features, args.max_rows)
    ref_glob = resolve_reference_glob(args.reference)
    # Bug fix: glob.glob honors wildcards in ANY path component; the old
    # Path(ref_glob).parent.glob(Path(ref_glob).name) only matched
    # wildcards in the final component, so patterns like data/*/x.csv
    # silently matched nothing.
    ref_paths = sorted(Path(p) for p in glob.glob(ref_glob))
    if not ref_paths:
        raise SystemExit(f"no reference files matched: {ref_glob}")
    # Concatenate reference series across all matched files.
    real_vals: Dict[str, List[float]] = {c: [] for c in features}
    for p in ref_paths:
        vals = read_series(p, features, args.max_rows)
        for c in features:
            real_vals[c].extend(vals[c])
    out = {"features": features, "generated": {}, "reference": {}}
    for c in features:
        out["generated"][c] = spike_stats(gen_vals[c])
        out["reference"][c] = spike_stats(real_vals[c])
    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(out, indent=2), encoding="utf-8")
    print("wrote", out_path)


if __name__ == "__main__":
    main()