109 lines
3.8 KiB
Python
109 lines
3.8 KiB
Python
#!/usr/bin/env python3
|
|
"""Stats for auxiliary/vibration signals (Type 6)."""
|
|
|
|
import argparse
|
|
import csv
|
|
import gzip
|
|
import json
|
|
from pathlib import Path
|
|
from typing import Dict, List
|
|
|
|
|
|
def parse_args():
    """Parse command-line options for the aux-stats script.

    Default paths are anchored to the directory containing this script:
    ``results/generated.csv`` for generated data, ``config.json`` for both
    the reference spec and the feature config, and ``results/aux_stats.json``
    for the output file.

    Returns:
        argparse.Namespace with attributes generated, reference, features,
        config, out and max_rows.
    """
    script_dir = Path(__file__).resolve().parent
    results_dir = script_dir / "results"
    # The reference spec and the feature config default to the same file.
    default_config = str(script_dir / "config.json")

    cli = argparse.ArgumentParser(description="Aux stats.")
    cli.add_argument("--generated", default=str(results_dir / "generated.csv"))
    cli.add_argument("--reference", default=default_config)
    cli.add_argument("--features", default="", help="comma-separated list")
    cli.add_argument("--config", default=default_config)
    cli.add_argument("--out", default=str(results_dir / "aux_stats.json"))
    cli.add_argument("--max-rows", type=int, default=200_000)
    return cli.parse_args()
|
|
|
|
|
|
def resolve_reference_glob(ref_arg: str) -> str:
    """Turn the ``--reference`` argument into a file path or glob pattern.

    A non-``.json`` argument is returned unchanged (as normalized by
    ``Path``). A ``.json`` argument is read as a config file. Its
    ``data_glob`` entry (or, if that is missing/empty, ``data_path``) is
    interpreted relative to the config file's directory.

    Values containing ``*`` or ``?`` are returned without resolving, so the
    wildcard survives. Anything else is resolved to an absolute path.

    Raises:
        SystemExit: if the config provides neither ``data_glob`` nor
            ``data_path``.
    """
    ref_path = Path(ref_arg)
    if ref_path.suffix != ".json":
        return str(ref_path)

    config = json.loads(ref_path.read_text(encoding="utf-8"))
    pattern = config.get("data_glob") or config.get("data_path") or ""
    if not pattern:
        raise SystemExit("reference config has no data_glob/data_path")

    target = ref_path.parent / pattern
    target_text = str(target)
    has_wildcard = any(ch in target_text for ch in ("*", "?"))
    return target_text if has_wildcard else str(target.resolve())
|
|
|
|
|
|
def read_series(path: Path, cols: List[str], max_rows: int) -> Dict[str, List[float]]:
    """Read numeric values for the given columns from a CSV file.

    Args:
        path: CSV file to read. A name ending in ``.gz`` is opened with
            ``gzip.open``; otherwise the built-in ``open`` is used.
        cols: Column (header) names to extract.
        max_rows: Maximum number of data rows to scan. ``<= 0`` means no limit.

    Returns:
        A dict mapping each requested column to the finite float values found
        in it, in file order. These cells are skipped:
        - missing columns or short rows
        - blank or non-numeric cells
        - NaN or infinite values
        Because cells are skipped per column, the lists can differ in length.
    """
    import math
    from itertools import islice

    vals: Dict[str, List[float]] = {c: [] for c in cols}
    opener = gzip.open if str(path).endswith(".gz") else open
    with opener(path, "rt", newline="") as fh:
        reader = csv.DictReader(fh)
        rows = islice(reader, max_rows) if max_rows > 0 else reader
        for row in rows:
            for c in cols:
                try:
                    value = float(row[c])
                except (KeyError, TypeError, ValueError):
                    # KeyError: column absent from the header.
                    # TypeError: short row, DictReader filled the cell with None.
                    # ValueError: blank or non-numeric text.
                    continue
                # NaN/inf would poison mean/std and serialize as non-standard
                # JSON (NaN/Infinity), so treat them like unusable cells.
                if math.isfinite(value):
                    vals[c].append(value)
    return vals
|
|
|
|
|
|
def mean_std(series: List[float]):
    """Summarize a numeric series.

    Returns:
        A dict with keys, in this order:
        - ``"mean"``: arithmetic mean.
        - ``"std"``: sample standard deviation. The divisor is n - 1, floored
          at 1, so a single value gives 0.0.
        - ``"lag1"``: Pearson correlation between ``series[:-1]`` and
          ``series[1:]``.

        All three are None for an empty series. ``"lag1"`` is also None when
        there are fewer than two values, or when either lagged half has zero
        variance.
    """
    if not series:
        return {"mean": None, "std": None, "lag1": None}

    count = len(series)
    center = sum(series) / count
    spread = sum((v - center) ** 2 for v in series) / max(count - 1, 1)
    summary = {"mean": center, "std": spread ** 0.5, "lag1": None}
    if count < 2:
        return summary

    # Lag-1 correlation: pair each value with its successor.
    head = series[:-1]
    tail = series[1:]
    head_mean = sum(head) / len(head)
    tail_mean = sum(tail) / len(tail)
    cross = sum((p - head_mean) * (q - tail_mean) for p, q in zip(head, tail))
    head_ss = sum((p - head_mean) ** 2 for p in head)
    tail_ss = sum((q - tail_mean) ** 2 for q in tail)
    if head_ss > 0 and tail_ss > 0:
        summary["lag1"] = cross / (head_ss ** 0.5 * tail_ss ** 0.5)
    return summary
|
|
|
|
|
|
def _load_features(features_arg: str, config_path: str) -> List[str]:
    """Return feature names from --features, else from the config's "type6_features"."""
    features = [f.strip() for f in features_arg.split(",") if f.strip()]
    if features or not Path(config_path).exists():
        return features
    cfg = json.loads(Path(config_path).read_text(encoding="utf-8"))
    configured = cfg.get("type6_features", []) or []
    if isinstance(configured, str):
        # Accept a comma-separated string as well as a JSON list; iterating a
        # string directly would yield one "feature" per character.
        return [f.strip() for f in configured.split(",") if f.strip()]
    return list(configured)


def _match_reference_files(pattern: str) -> List[Path]:
    """Expand a path or glob pattern (wildcards in any component, ``**`` recursive) to sorted files."""
    import glob

    # glob.glob handles wildcards in directory components and absolute
    # patterns, which Path(pattern).parent.glob(Path(pattern).name) does not.
    # Directories are dropped because read_series cannot open them.
    return sorted(Path(p) for p in glob.glob(pattern, recursive=True) if Path(p).is_file())


def main():
    """Compute summary stats for generated vs. reference aux signals and write them as JSON.

    Features come from ``--features`` or, if that is empty, from the
    ``type6_features`` entry of ``--config``. Reference files are found by
    expanding the path/pattern returned by ``resolve_reference_glob``.

    Raises:
        SystemExit: if no features are specified or no reference files match.
    """
    args = parse_args()
    features = _load_features(args.features, args.config)
    if not features:
        raise SystemExit("no features specified for aux_stats")

    gen_vals = read_series(Path(args.generated), features, args.max_rows)

    ref_glob = resolve_reference_glob(args.reference)
    ref_paths = _match_reference_files(ref_glob)
    if not ref_paths:
        raise SystemExit(f"no reference files matched: {ref_glob}")

    # NOTE: --max-rows is applied to each reference file separately, so the
    # pooled reference series can hold up to max_rows * len(ref_paths) values.
    # Files are concatenated before mean_std, so the reference "lag1" also
    # pairs the last value of one file with the first value of the next.
    real_vals: Dict[str, List[float]] = {c: [] for c in features}
    for p in ref_paths:
        vals = read_series(p, features, args.max_rows)
        for c in features:
            real_vals[c].extend(vals[c])

    out = {
        "features": features,
        "generated": {c: mean_std(gen_vals[c]) for c in features},
        "reference": {c: mean_std(real_vals[c]) for c in features},
    }

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(out, indent=2), encoding="utf-8")
    print("wrote", out_path)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|