#!/usr/bin/env python3
"""Shared utilities for evaluation, downstream utility, and ablations."""

from __future__ import annotations

import csv
import gzip
import json
import math
import random
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

import numpy as np

# Characters that turn a path string into a glob pattern rather than a literal path.
_GLOB_CHARS = ("*", "?", "[")


def _has_glob_chars(text: str) -> bool:
    """Return True when *text* contains any glob wildcard character."""
    return any(ch in text for ch in _GLOB_CHARS)


def load_json(path: str | Path) -> Dict:
    """Load a UTF-8 encoded JSON file and return the decoded object."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def open_csv(path: str | Path):
    """Open a CSV file for text reading, transparently decompressing ``.gz`` files.

    The file is decoded as UTF-8 (matching :func:`load_json`) and opened with
    ``newline=""`` as the :mod:`csv` module requires.  The caller owns the returned
    file object and should close it, e.g. via ``with open_csv(p) as f``.
    """
    path = str(path)
    if path.endswith(".gz"):
        return gzip.open(path, "rt", newline="", encoding="utf-8")
    return open(path, "r", newline="", encoding="utf-8")


def resolve_path(base_dir: Path, path_like: str | Path) -> Path:
    """Return *path_like* unchanged if absolute, else resolved against *base_dir*."""
    path = Path(path_like)
    if path.is_absolute():
        return path
    return (base_dir / path).resolve()


def _config_data_ref(cfg: Dict) -> str:
    """Return the config's ``data_glob``/``data_path`` entry or exit if neither is set."""
    data_ref = cfg.get("data_glob") or cfg.get("data_path") or ""
    if not data_ref:
        raise SystemExit("reference config has no data_glob/data_path")
    return data_ref


def resolve_reference_paths(ref_arg: str | Path) -> List[str]:
    """Expand the reference (training) data files.

    If *ref_arg* is an existing ``.json`` config, its ``data_glob`` (or ``data_path``)
    entry is expanded relative to the config's directory; otherwise *ref_arg* itself is
    treated as a glob pattern or file path.

    Raises:
        SystemExit: the config has neither ``data_glob`` nor ``data_path``.
    """
    ref_path = Path(ref_arg)
    if ref_path.suffix == ".json" and ref_path.exists():
        cfg = load_json(ref_path)
        return expand_glob_or_file(ref_path.parent / _config_data_ref(cfg))
    return expand_glob_or_file(ref_path)


def infer_test_paths(ref_arg: str | Path) -> List[str]:
    """Locate held-out test files that correspond to a reference (training) source.

    For an existing ``.json`` config, an explicit ``test_glob``/``test_path`` wins;
    otherwise the search falls back to the config's training data reference.  For a
    path or glob, candidates are tried in order and the first with matches is returned:

    1. the name with ``train`` replaced by ``test`` (only if ``train`` occurs),
    2. the name with ``TRAIN`` replaced by ``TEST`` (only if ``TRAIN`` occurs),
    3. ``test*.csv.gz`` and then ``test*.csv`` in the same directory.

    A replacement is only attempted when the token is present, so the reference file
    itself is never offered as a "test" candidate.

    Raises:
        SystemExit: no candidate matched, or the config lacks a data reference.
    """
    ref_path = Path(ref_arg)
    if ref_path.suffix == ".json" and ref_path.exists():
        cfg = load_json(ref_path)
        cfg_base = ref_path.parent
        explicit = cfg.get("test_glob") or cfg.get("test_path") or ""
        if explicit:
            return expand_glob_or_file(cfg_base / explicit)
        return infer_test_paths(cfg_base / _config_data_ref(cfg))

    parent = ref_path.parent
    name = ref_path.name
    candidates: List[Path] = [
        parent / name.replace(train_tok, test_tok)
        for train_tok, test_tok in (("train", "test"), ("TRAIN", "TEST"))
        if train_tok in name
    ]
    candidates.append(parent / "test*.csv.gz")
    candidates.append(parent / "test*.csv")

    for candidate in candidates:
        matches = expand_glob_or_file(candidate)
        if matches:
            return matches
    raise SystemExit(f"could not infer test files from reference: {ref_arg}")


def expand_glob_or_file(path_like: str | Path) -> List[str]:
    """Expand a glob pattern or literal path into a sorted list of absolute paths.

    Wildcards may appear in any path component (e.g. ``data/*/train*.csv`` or
    ``data/**/x.csv``); the pattern is anchored at the longest wildcard-free prefix.
    A literal path yields a one-element list if it exists, else an empty list.
    """
    path = Path(path_like)
    if not _has_glob_chars(str(path)):
        return [str(path.resolve())] if path.exists() else []
    parts = path.parts
    first_wild = next(i for i, part in enumerate(parts) if _has_glob_chars(part))
    base = Path(*parts[:first_wild]) if first_wild else Path(".")
    pattern = str(Path(*parts[first_wild:]))
    return [str(p.resolve()) for p in sorted(base.glob(pattern))]


def load_split_columns(split_path: str | Path) -> Tuple[str, List[str], List[str], List[str]]:
    """Read a column-split JSON file.

    Returns ``(time_col, cont_cols, disc_cols, label_cols)``.  The time column is
    dropped from both lists.  Discrete columns whose lowercase name starts with
    ``"attack"`` are returned as label columns instead of discrete features.
    """
    split = load_json(split_path)
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_all = [c for c in split["discrete"] if c != time_col]
    label_cols = [c for c in disc_all if c.lower().startswith("attack")]
    label_set = set(label_cols)
    disc_cols = [c for c in disc_all if c not in label_set]
    return time_col, cont_cols, disc_cols, label_cols


def load_vocab(vocab_path: str | Path, disc_cols: Sequence[str]) -> Tuple[Dict[str, Dict[str, int]], List[int]]:
    """Load the ``"vocab"`` mapping and the vocabulary size of each column in *disc_cols*."""
    vocab = load_json(vocab_path)["vocab"]
    vocab_sizes = [len(vocab[c]) for c in disc_cols]
    return vocab, vocab_sizes


def load_stats_vectors(stats_path: str | Path, cont_cols: Sequence[str]) -> Tuple[np.ndarray, np.ndarray]:
    """Return float32 ``(mean, std)`` vectors ordered like *cont_cols*.

    ``raw_mean``/``raw_std`` are preferred and ``mean``/``std`` are the fallback; the
    fallback key is only read when the preferred one is missing.  Standard deviations
    are floored at ``1e-6`` so they can be used as divisors.
    """
    stats = load_json(stats_path)
    # Conditional lookup: dict.get(k, stats["mean"]) would evaluate the fallback eagerly
    # and raise KeyError for files that only contain the raw_* keys.
    mean = stats["raw_mean"] if "raw_mean" in stats else stats["mean"]
    std = stats["raw_std"] if "raw_std" in stats else stats["std"]
    mean_vec = np.asarray([float(mean[c]) for c in cont_cols], dtype=np.float32)
    std_vec = np.asarray([max(float(std[c]), 1e-6) for c in cont_cols], dtype=np.float32)
    return mean_vec, std_vec


def _row_label(row: Dict[str, str], label_cols: Sequence[str]) -> int:
    """Return the maximum integer value across *label_cols*; unparsable cells are skipped."""
    label = 0
    for c in label_cols:
        try:
            label = max(label, int(float(row.get(c, 0) or 0)))
        except (TypeError, ValueError, OverflowError):
            continue
    return label


def load_rows(
    path: str | Path,
    cont_cols: Sequence[str],
    disc_cols: Sequence[str],
    label_cols: Optional[Sequence[str]] = None,
    vocab: Optional[Dict[str, Dict[str, int]]] = None,
    max_rows: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
    """Read a (possibly gzipped) CSV into per-row feature arrays.

    Args:
        path: CSV file, read via :func:`open_csv`.
        cont_cols: Columns parsed as floats.
        disc_cols: Discrete columns.  Without *vocab* they are parsed as
            ``int(float(cell))``.  With *vocab* each raw cell string is mapped through
            ``vocab[col]``, and unknown values map to the ``""`` entry or ``0``.
        label_cols: Optional columns whose row-wise maximum integer value forms the
            label.  Unparsable cells are ignored.
        vocab: Optional per-column string-to-code mapping.
        max_rows: Read at most this many data rows.  ``0`` reads none.

    Returns:
        ``(cont, disc, labels)`` with shapes ``(n, len(cont_cols))`` float32,
        ``(n, len(disc_cols))`` int64, and ``(n,)`` int64.  ``labels`` is ``None`` when
        no label columns are given.
    """
    label_cols = list(label_cols or [])
    disc_cols = list(disc_cols)
    # Hoist per-column vocab lookups out of the row loop.
    mappings = [vocab[c] for c in disc_cols] if vocab is not None else None
    cont_rows: List[List[float]] = []
    disc_rows: List[List[int]] = []
    label_rows: List[int] = []
    with open_csv(path) as f:
        for idx, row in enumerate(csv.DictReader(f)):
            # Check before reading so max_rows=0 really yields no rows.
            if max_rows is not None and idx >= max_rows:
                break
            cont_rows.append([float(row[c]) for c in cont_cols])
            if mappings is None:
                disc_rows.append([int(float(row[c])) for c in disc_cols])
            else:
                disc_rows.append(
                    [m.get(row.get(c, ""), m.get("", 0)) for c, m in zip(disc_cols, mappings)]
                )
            if label_cols:
                label_rows.append(_row_label(row, label_cols))
    n_rows = len(cont_rows)
    # Explicit reshape keeps the row count aligned even for zero-width feature sets.
    cont = np.asarray(cont_rows, dtype=np.float32).reshape(n_rows, len(cont_cols))
    disc = np.asarray(disc_rows, dtype=np.int64).reshape(n_rows, len(disc_cols))
    labels = np.asarray(label_rows, dtype=np.int64) if label_cols else None
    return cont, disc, labels


def window_array(
    cont: np.ndarray,
    disc: np.ndarray,
    labels: Optional[np.ndarray],
    seq_len: int,
    stride: Optional[int] = None,
    max_windows: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
    """Cut row-aligned arrays into fixed-length windows.

    Windows start at ``0, stride, 2*stride, ...`` and only complete windows are
    kept.  A missing or non-positive *stride* defaults to *seq_len*, i.e. windows do
    not overlap.  At most *max_windows* windows are produced; ``0`` produces none.
    A window's label is the maximum row label inside it.

    Returns:
        ``(cont_windows, disc_windows, labels)`` with shapes
        ``(w, seq_len, cont.shape[1])`` float32, ``(w, seq_len, disc.shape[1])``
        int64, and ``(w,)`` int64.  ``labels`` is ``None`` if *labels* is ``None``.

    Raises:
        ValueError: *seq_len* is not positive.
    """
    if seq_len <= 0:
        raise ValueError(f"seq_len must be positive, got {seq_len}")
    if stride is None or stride <= 0:
        stride = seq_len
    n_rows = cont.shape[0]
    if disc.ndim == 2 and disc.shape[1] == 0 and disc.shape[0] != n_rows:
        # No discrete columns: give the zero-width block one (empty) row per time step.
        disc = np.zeros((n_rows, 0), dtype=np.int64)
    starts = np.arange(0, max(n_rows - seq_len + 1, 0), stride, dtype=np.int64)
    if max_windows is not None:
        starts = starts[: max(int(max_windows), 0)]
    # index[w, t] is the row of time step t in window w.
    index = starts[:, None] + np.arange(seq_len, dtype=np.int64)[None, :]
    cont_out = cont[index].astype(np.float32, copy=False)
    disc_out = disc[index].astype(np.int64, copy=False)
    label_out = labels[index].max(axis=1).astype(np.int64) if labels is not None else None
    return cont_out, disc_out, label_out


def load_windows_from_paths(
    paths: Sequence[str],
    cont_cols: Sequence[str],
    disc_cols: Sequence[str],
    seq_len: int,
    vocab: Optional[Dict[str, Dict[str, int]]] = None,
    label_cols: Optional[Sequence[str]] = None,
    stride: Optional[int] = None,
    max_windows: Optional[int] = None,
    max_rows_per_file: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
    """Load and window several CSV files and concatenate the results.

    Files are processed in order until *max_windows* windows have been collected.
    Windows never span file boundaries.  When a window budget is set, only the rows
    needed for the remaining windows are read from each file.

    Returns:
        Same layout as :func:`window_array`.  Labels are ``None`` when no label
        columns are given.
    """
    step = stride if stride is not None and stride > 0 else seq_len
    cont_all: List[np.ndarray] = []
    disc_all: List[np.ndarray] = []
    label_all: List[np.ndarray] = []
    total = 0
    for path in paths:
        remaining = None if max_windows is None else max(0, max_windows - total)
        if remaining == 0:
            break
        row_cap = max_rows_per_file
        if remaining is not None:
            # The first `remaining` windows only touch this many leading rows.
            needed = (remaining - 1) * step + seq_len
            row_cap = needed if row_cap is None else min(row_cap, needed)
        cont, disc, labels = load_rows(
            path,
            cont_cols,
            disc_cols,
            label_cols=label_cols,
            vocab=vocab,
            max_rows=row_cap,
        )
        w_cont, w_disc, w_labels = window_array(
            cont,
            disc,
            labels,
            seq_len=seq_len,
            stride=stride,
            max_windows=remaining,
        )
        # Count windows, not elements: w_cont.size is 0 when there are no continuous columns.
        if w_cont.shape[0] == 0:
            continue
        cont_all.append(w_cont)
        disc_all.append(w_disc)
        if w_labels is not None:
            label_all.append(w_labels)
        total += w_cont.shape[0]

    if not cont_all:
        empty_cont = np.zeros((0, seq_len, len(cont_cols)), dtype=np.float32)
        empty_disc = np.zeros((0, seq_len, len(disc_cols)), dtype=np.int64)
        empty_labels = np.zeros((0,), dtype=np.int64) if label_cols else None
        return empty_cont, empty_disc, empty_labels

    cont_out = np.concatenate(cont_all, axis=0)
    disc_out = np.concatenate(disc_all, axis=0)
    label_out = np.concatenate(label_all, axis=0) if label_all else None
    return cont_out, disc_out, label_out


def filter_windows_by_label(
    cont_windows: np.ndarray,
    disc_windows: np.ndarray,
    labels: Optional[np.ndarray],
    target_label: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Keep only the windows whose label equals *target_label*.

    Raises:
        ValueError: *labels* is ``None``.
    """
    if labels is None:
        raise ValueError("labels are required for label filtering")
    mask = labels == int(target_label)
    return cont_windows[mask], disc_windows[mask], labels[mask]


def standardize_cont_windows(
    cont_windows: np.ndarray,
    mean_vec: np.ndarray,
    std_vec: np.ndarray,
) -> np.ndarray:
    """Z-score ``(n, seq_len, n_features)`` windows with per-feature mean/std vectors."""
    return (cont_windows - mean_vec.reshape(1, 1, -1)) / std_vec.reshape(1, 1, -1)


def _flatten_windows(windows: np.ndarray) -> np.ndarray:
    """Reshape ``(n, ...)`` to ``(n, prod(...))``; unlike ``reshape(n, -1)`` it works for n == 0."""
    return windows.reshape(windows.shape[0], int(np.prod(windows.shape[1:], dtype=np.int64)))


def build_flat_window_vectors(
    cont_windows: np.ndarray,
    disc_windows: np.ndarray,
    mean_vec: np.ndarray,
    std_vec: np.ndarray,
    vocab_sizes: Sequence[int],
) -> np.ndarray:
    """Flatten windows into one float32 vector per window.

    Continuous features are z-scored.  Discrete codes are divided by
    ``max(vocab_size - 1, 1)`` so they fall in ``[0, 1]`` for in-vocabulary codes.
    Discrete features are omitted when there are no discrete columns.
    """
    cont_norm = _flatten_windows(standardize_cont_windows(cont_windows, mean_vec, std_vec))
    if disc_windows.ndim < 3 or disc_windows.shape[2] == 0:
        return cont_norm.astype(np.float32)
    disc_scale = np.asarray([max(v - 1, 1) for v in vocab_sizes], dtype=np.float32).reshape(1, 1, -1)
    disc_norm = _flatten_windows(disc_windows.astype(np.float32) / disc_scale)
    return np.concatenate([cont_norm, disc_norm], axis=1).astype(np.float32)


def build_histogram_embeddings(
    cont_windows: np.ndarray,
    disc_windows: np.ndarray,
    mean_vec: np.ndarray,
    std_vec: np.ndarray,
    vocab_sizes: Sequence[int],
) -> np.ndarray:
    """Embed windows as z-scored continuous values plus per-column code histograms.

    For each discrete column, the embedding holds the fraction of time steps in the
    window that take each integer code ``0..vocab_size-1``.  Codes outside that range
    are not counted.
    """
    cont_norm = _flatten_windows(standardize_cont_windows(cont_windows, mean_vec, std_vec))
    if disc_windows.ndim < 3 or disc_windows.shape[2] == 0:
        return cont_norm.astype(np.float32)
    n_windows, seq_len = disc_windows.shape[0], disc_windows.shape[1]
    # Window id of every flattened (window, time-step) cell, in row-major order.
    window_ids = np.repeat(np.arange(n_windows, dtype=np.int64), seq_len)
    hist_features: List[np.ndarray] = []
    for disc_idx, vocab_size in enumerate(vocab_sizes):
        values = disc_windows[:, :, disc_idx].reshape(-1)
        valid = (values >= 0) & (values < vocab_size)
        # One bincount per column instead of one comparison pass per vocabulary value.
        flat_bins = window_ids[valid] * vocab_size + values[valid].astype(np.int64)
        counts = np.bincount(flat_bins, minlength=n_windows * vocab_size)
        hist = counts.reshape(n_windows, vocab_size) / max(seq_len, 1)
        hist_features.append(hist.astype(np.float32))
    disc_hist = (
        np.concatenate(hist_features, axis=1)
        if hist_features
        else np.zeros((cont_windows.shape[0], 0), dtype=np.float32)
    )
    return np.concatenate([cont_norm, disc_hist], axis=1).astype(np.float32)


def sample_indices(n_items: int, max_items: Optional[int], seed: int) -> np.ndarray:
    """Return all indices, or a sorted random subset of *max_items* indices drawn with *seed*."""
    if max_items is None or n_items <= max_items:
        return np.arange(n_items, dtype=np.int64)
    rng = np.random.default_rng(seed)
    return np.sort(rng.choice(n_items, size=max_items, replace=False))


def subset_by_indices(array: Optional[np.ndarray], indices: np.ndarray) -> Optional[np.ndarray]:
    """Index *array* along its first axis; ``None`` passes through unchanged."""
    if array is None:
        return array
    return array[indices]


def compute_corr_matrix(rows: np.ndarray) -> np.ndarray:
    """Return the ``(n_features, n_features)`` Pearson correlation matrix of *rows*.

    Undefined entries (e.g. constant features) are set to 0.  The result is all zeros
    when fewer than two rows are given, and always 2-D, even for a single feature.
    """
    n_features = rows.shape[1]
    if rows.shape[0] < 2 or n_features == 0:
        return np.zeros((n_features, n_features), dtype=np.float32)
    with np.errstate(divide="ignore", invalid="ignore"):
        matrix = np.corrcoef(rows, rowvar=False)
    # corrcoef returns a 0-d array for a single feature.
    matrix = np.atleast_2d(np.nan_to_num(matrix, nan=0.0, posinf=0.0, neginf=0.0))
    return matrix.astype(np.float32)


def compute_lagged_corr_matrix(rows: np.ndarray, lag: int = 1) -> np.ndarray:
    """Correlation between each feature at time t and each feature at time t + *lag*.

    Entry ``[i, j]`` correlates feature ``i`` (leading) with feature ``j`` (lagged).
    ``lag=0`` gives the contemporaneous correlation.  The result is all zeros when
    there are not more than *lag* rows.

    Raises:
        ValueError: *lag* is negative.
    """
    if lag < 0:
        raise ValueError(f"lag must be non-negative, got {lag}")
    n_rows = rows.shape[0]
    if n_rows <= lag:
        return np.zeros((rows.shape[1], rows.shape[1]), dtype=np.float32)
    # Slice with explicit ends: rows[:-0] would be empty for lag=0.
    x = rows[: n_rows - lag]
    y = rows[lag:]
    x = x - x.mean(axis=0, keepdims=True)
    y = y - y.mean(axis=0, keepdims=True)
    cov = x.T @ y / max(x.shape[0] - 1, 1)
    std_x = np.sqrt(np.maximum((x ** 2).sum(axis=0) / max(x.shape[0] - 1, 1), 1e-8))
    std_y = np.sqrt(np.maximum((y ** 2).sum(axis=0) / max(y.shape[0] - 1, 1), 1e-8))
    denom = np.outer(std_x, std_y)
    corr = cov / np.maximum(denom, 1e-8)
    return np.nan_to_num(corr, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)


def split_process_groups(feature_names: Sequence[str]) -> Dict[str, List[int]]:
    """Group feature indices by name prefix (text before the first ``_``)."""
    groups: Dict[str, List[int]] = {}
    for idx, name in enumerate(feature_names):
        prefix = name.split("_", 1)[0]
        groups.setdefault(prefix, []).append(idx)
    return groups


def compute_average_psd(cont_windows: np.ndarray) -> np.ndarray:
    """Average normalized power spectrum per feature.

    Each window is mean-centered along the time axis (axis 1) before an ``rfft``.
    Power is averaged over windows and normalized to sum to 1 per feature.

    Returns:
        An array of shape ``(n_features, seq_len // 2 + 1)``, or ``(0, 0)`` for empty
        input.
    """
    if cont_windows.size == 0:
        return np.zeros((0, 0), dtype=np.float32)
    centered = cont_windows - cont_windows.mean(axis=1, keepdims=True)
    spectrum = np.fft.rfft(centered, axis=1)
    power = (np.abs(spectrum) ** 2).mean(axis=0).T
    power = power.astype(np.float32)
    norm = power.sum(axis=1, keepdims=True)
    norm[norm <= 0] = 1.0  # all-zero spectra stay zero instead of dividing by 0
    return power / norm


def psd_distance_stats(real_psd: np.ndarray, gen_psd: np.ndarray) -> Dict[str, float]:
    """Compare two per-feature PSD matrices of shape ``(n_features, n_freq)``.

    Reports the mean over features of three quantities:

    - the mean absolute difference;
    - the cosine distance;
    - the absolute difference of the low/high energy ratio, where "low" is the first
      quarter of frequency bins (at least one bin).

    All values are NaN when either input is empty.
    """
    if real_psd.size == 0 or gen_psd.size == 0:
        return {
            "avg_psd_l1": float("nan"),
            "avg_psd_cosine": float("nan"),
            "avg_low_high_ratio_abs_diff": float("nan"),
        }
    l1 = np.abs(real_psd - gen_psd).mean(axis=1)
    cosine = []
    ratio_diffs = []
    n_freq = real_psd.shape[1]
    split = max(1, n_freq // 4)
    for rv, gv in zip(real_psd, gen_psd):
        denom = max(np.linalg.norm(rv) * np.linalg.norm(gv), 1e-8)
        cosine.append(1.0 - float(np.dot(rv, gv) / denom))
        r_ratio = float(rv[:split].sum()) / max(float(rv[split:].sum()), 1e-8)
        g_ratio = float(gv[:split].sum()) / max(float(gv[split:].sum()), 1e-8)
        ratio_diffs.append(abs(r_ratio - g_ratio))
    return {
        "avg_psd_l1": float(np.mean(l1)),
        "avg_psd_cosine": float(np.mean(cosine)),
        "avg_low_high_ratio_abs_diff": float(np.mean(ratio_diffs)),
    }


def pairwise_sq_dists(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Squared Euclidean distances between rows of *x* and rows of *y*, clamped at 0."""
    x_norm = (x ** 2).sum(axis=1, keepdims=True)
    y_norm = (y ** 2).sum(axis=1, keepdims=True).T
    d2 = x_norm + y_norm - 2.0 * (x @ y.T)
    return np.maximum(d2, 0.0)


def median_heuristic_gamma(x: np.ndarray, y: np.ndarray, max_points: int = 256) -> float:
    """RBF ``gamma = 1 / (2 * median positive squared distance)`` over pooled samples.

    At most *max_points* rows are used.  They are drawn from both *x* and *y* (about
    half each; the remainder goes to whichever set has more rows), so a large *x*
    cannot crowd *y* out of the estimate.  When both sets together fit in
    *max_points*, every row is used.  Returns 1.0 when no positive distance exists.
    """
    n_x, n_y = x.shape[0], y.shape[0]
    if n_x + n_y <= 1:
        return 1.0
    k_y = min(n_y, max_points - min(n_x, max_points // 2))
    k_x = min(n_x, max_points - k_y)
    pooled = np.concatenate([x[:k_x], y[:k_y]], axis=0)
    d2 = pairwise_sq_dists(pooled, pooled)
    upper = d2[np.triu_indices_from(d2, k=1)]
    upper = upper[upper > 0]
    if upper.size == 0:
        return 1.0
    median = float(np.median(upper))
    return 1.0 / max(2.0 * median, 1e-8)


def rbf_mmd(x: np.ndarray, y: np.ndarray, gamma: Optional[float] = None) -> Tuple[float, float]:
    """Unbiased squared MMD with kernel ``exp(-gamma * ||a - b||^2)``.

    *gamma* defaults to :func:`median_heuristic_gamma`.  Returns ``(mmd2, gamma)``,
    or ``(nan, 1.0)`` when either sample set is empty.
    """
    if x.shape[0] == 0 or y.shape[0] == 0:
        return float("nan"), 1.0
    if gamma is None:
        gamma = median_heuristic_gamma(x, y)
    k_xx = np.exp(-gamma * pairwise_sq_dists(x, x))
    k_yy = np.exp(-gamma * pairwise_sq_dists(y, y))
    k_xy = np.exp(-gamma * pairwise_sq_dists(x, y))
    m = x.shape[0]
    n = y.shape[0]
    # Within-set terms exclude the diagonal (self-similarity) for unbiasedness.
    term_xx = (k_xx.sum() - np.trace(k_xx)) / (m * (m - 1)) if m > 1 else 0.0
    term_yy = (k_yy.sum() - np.trace(k_yy)) / (n * (n - 1)) if n > 1 else 0.0
    term_xy = 2.0 * k_xy.mean()
    return float(term_xx + term_yy - term_xy), float(gamma)


def duplicate_rate(vectors: np.ndarray, decimals: int = 5) -> float:
    """Fraction of rows that duplicate an earlier row after rounding to *decimals*."""
    if vectors.shape[0] <= 1:
        return 0.0
    rounded = np.round(vectors, decimals=decimals)
    unique = np.unique(rounded, axis=0).shape[0]
    return float(1.0 - unique / vectors.shape[0])


def exact_match_rate(query: np.ndarray, base: np.ndarray, decimals: int = 5) -> float:
    """Fraction of *query* rows that appear in *base* after rounding both to *decimals*."""
    if query.shape[0] == 0 or base.shape[0] == 0:
        return 0.0
    rounded_base = {tuple(row.tolist()) for row in np.round(base, decimals=decimals)}
    rounded_query = np.round(query, decimals=decimals)
    matches = sum(1 for row in rounded_query if tuple(row.tolist()) in rounded_base)
    return float(matches / query.shape[0])


def nearest_neighbor_distance_stats(query: np.ndarray, base: np.ndarray, batch_size: int = 128) -> Dict[str, float]:
    """Mean/median/min Euclidean distance from each *query* row to its nearest *base* row.

    Queries are processed in batches of *batch_size* to bound the size of the
    distance matrix.  All values are NaN when either input is empty.
    """
    if query.shape[0] == 0 or base.shape[0] == 0:
        return {"mean": float("nan"), "median": float("nan"), "min": float("nan")}
    dists: List[np.ndarray] = []
    for start in range(0, query.shape[0], batch_size):
        chunk = query[start : start + batch_size]
        d2 = pairwise_sq_dists(chunk, base)
        dists.append(np.sqrt(np.min(d2, axis=1)))
    values = np.concatenate(dists, axis=0)
    return {
        "mean": float(values.mean()),
        "median": float(np.median(values)),
        "min": float(values.min()),
    }


def one_nn_two_sample_accuracy(real_vecs: np.ndarray, gen_vecs: np.ndarray) -> float:
    """Leave-one-out 1-NN accuracy at telling real rows from generated rows.

    Values near 0.5 mean the sets are hard to distinguish.  Returns NaN when either
    set has fewer than two rows.
    """
    if real_vecs.shape[0] < 2 or gen_vecs.shape[0] < 2:
        return float("nan")
    x = np.concatenate([real_vecs, gen_vecs], axis=0)
    y = np.concatenate(
        [
            np.zeros(real_vecs.shape[0], dtype=np.int64),
            np.ones(gen_vecs.shape[0], dtype=np.int64),
        ],
        axis=0,
    )
    d2 = pairwise_sq_dists(x, x)
    np.fill_diagonal(d2, np.inf)  # a point may not be its own neighbour
    nn = np.argmin(d2, axis=1)
    pred = y[nn]
    return float((pred == y).mean())


def binary_classification_curves(y_true: np.ndarray, scores: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Per-sample ROC/precision points, with samples ranked by descending score.

    Returns ``(scores_sorted, fpr, tpr, precision)``, with one entry per sample.
    Tied scores are ordered stably (by input position).  Use
    :func:`binary_auroc` / :func:`binary_average_precision` for tie-aware metrics.
    """
    order = np.argsort(-scores, kind="mergesort")
    y = y_true[order].astype(np.int64)
    scores_sorted = scores[order]
    tp = np.cumsum(y == 1)
    fp = np.cumsum(y == 0)
    positives = max(int((y_true == 1).sum()), 1)
    negatives = max(int((y_true == 0).sum()), 1)
    tpr = tp / positives
    fpr = fp / negatives
    precision = tp / np.maximum(tp + fp, 1)
    return scores_sorted, fpr, tpr, precision


def _counts_at_distinct_thresholds(
    y_true: np.ndarray, scores: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Cumulative counts at each distinct score, visited from highest to lowest.

    Returns ``(thresholds, tp, fp, n_ranked)``, where entry ``k`` counts every sample
    scored at or above ``thresholds[k]``.  Tied samples always fall in the same
    bucket, so results do not depend on tie order.
    """
    y_true = np.asarray(y_true)
    s = np.asarray(scores, dtype=np.float64)  # float64 avoids unsigned wrap-around on negation
    order = np.argsort(-s, kind="mergesort")
    s_sorted = s[order]
    y_sorted = y_true[order]
    if s_sorted.size == 0:
        empty = np.zeros(0, dtype=np.int64)
        return s_sorted, empty, empty, empty
    last = np.concatenate([np.flatnonzero(s_sorted[1:] != s_sorted[:-1]), [s_sorted.size - 1]])
    tp = np.cumsum(y_sorted == 1)[last]
    fp = np.cumsum(y_sorted == 0)[last]
    return s_sorted[last], tp, fp, last + 1


def binary_auroc(y_true: np.ndarray, scores: np.ndarray) -> float:
    """Area under the ROC curve (1 = positive, 0 = negative); NaN with one class.

    Tied scores are handled as a single threshold, so e.g. constant scores give 0.5.
    """
    y_true = np.asarray(y_true)
    if len(np.unique(y_true)) < 2:
        return float("nan")
    _, tp, fp, _ = _counts_at_distinct_thresholds(y_true, scores)
    positives = max(int((y_true == 1).sum()), 1)
    negatives = max(int((y_true == 0).sum()), 1)
    tpr = np.concatenate([[0.0], tp / positives, [1.0]])
    fpr = np.concatenate([[0.0], fp / negatives, [1.0]])
    # Trapezoid rule written out: np.trapz is deprecated in NumPy 2.
    return float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2.0))


def binary_average_precision(y_true: np.ndarray, scores: np.ndarray) -> float:
    """Average precision ``sum_k (R_k - R_{k-1}) * P_k`` over distinct score thresholds.

    ``P_k`` counts every sample at or above threshold ``k`` as predicted positive.
    Returns NaN when *y_true* has fewer than two classes.
    """
    y_true = np.asarray(y_true)
    if len(np.unique(y_true)) < 2:
        return float("nan")
    _, tp, _, n_ranked = _counts_at_distinct_thresholds(y_true, scores)
    positives = max(int((y_true == 1).sum()), 1)
    recall = tp / positives
    precision = tp / n_ranked
    recall_steps = np.diff(np.concatenate([[0.0], recall]))
    return float(np.sum(recall_steps * precision))


def binary_f1_at_threshold(y_true: np.ndarray, scores: np.ndarray, threshold: float) -> Dict[str, float]:
    """Precision, recall and F1 when predicting positive for ``scores >= threshold``."""
    pred = (scores >= threshold).astype(np.int64)
    tp = int(((pred == 1) & (y_true == 1)).sum())
    fp = int(((pred == 1) & (y_true == 0)).sum())
    fn = int(((pred == 0) & (y_true == 1)).sum())
    precision = tp / max(tp + fp, 1)
    recall = tp / max(tp + fn, 1)
    if precision + recall == 0:
        f1 = 0.0
    else:
        f1 = 2.0 * precision * recall / (precision + recall)
    return {"threshold": float(threshold), "precision": float(precision), "recall": float(recall), "f1": float(f1)}


def best_binary_f1(y_true: np.ndarray, scores: np.ndarray) -> Dict[str, float]:
    """Best F1 over all thresholds equal to an observed score.

    Produces the same result as calling :func:`binary_f1_at_threshold` at every
    distinct score.  Ties in F1 go to the lowest threshold.  Counts come from binary
    search over the sorted scores, which costs O(n log n).  NaN scores are never
    predicted positive.
    """
    y_true = np.asarray(y_true)
    nan_result = {"threshold": float("nan"), "precision": float("nan"), "recall": float("nan"), "f1": float("nan")}
    scores = np.asarray(scores, dtype=np.float64)
    if y_true.size == 0 or scores.size == 0:
        return nan_result
    thresholds = np.unique(scores)  # ascending, NaN last
    scored = ~np.isnan(scores)  # `nan >= t` is always False
    pos_scores = np.sort(scores[(y_true == 1) & scored])
    neg_scores = np.sort(scores[(y_true == 0) & scored])
    # Number of scores >= t = size - (number of scores < t).
    tp = pos_scores.size - np.searchsorted(pos_scores, thresholds, side="left")
    fp = neg_scores.size - np.searchsorted(neg_scores, thresholds, side="left")
    n_pos = int((y_true == 1).sum())
    precision = tp / np.maximum(tp + fp, 1)
    recall = tp / max(n_pos, 1)
    denom = precision + recall
    safe_denom = np.where(denom == 0, 1.0, denom)
    f1 = np.where(denom == 0, 0.0, 2.0 * precision * recall / safe_denom)
    best = int(np.argmax(f1))  # first maximum == lowest threshold
    return {
        "threshold": float(thresholds[best]),
        "precision": float(precision[best]),
        "recall": float(recall[best]),
        "f1": float(f1[best]),
    }


def set_random_seed(seed: int) -> None:
    """Seed Python's ``random`` module and NumPy's global RNG."""
    random.seed(seed)
    np.random.seed(seed)