Add comprehensive evaluation and ablation runner

This commit is contained in:
MZ YANG
2026-03-25 22:20:43 +08:00
parent f1afd4bf38
commit 957b010ea1
8 changed files with 1730 additions and 30 deletions

568
example/eval_utils.py Normal file
View File

@@ -0,0 +1,568 @@
#!/usr/bin/env python3
"""Shared utilities for evaluation, downstream utility, and ablations."""
from __future__ import annotations
import csv
import gzip
import json
import math
import random
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
import numpy as np
def load_json(path: str | Path) -> Dict:
    """Read a UTF-8 encoded JSON file and return the parsed object."""
    raw = Path(path).read_text(encoding="utf-8")
    return json.loads(raw)
def open_csv(path: str | Path):
    """Open a CSV file for reading in text mode, transparently decompressing .gz files."""
    name = str(path)
    is_gzipped = name.endswith(".gz")
    return gzip.open(name, "rt", newline="") if is_gzipped else open(name, "r", newline="")
def resolve_path(base_dir: Path, path_like: str | Path) -> Path:
    """Return *path_like* untouched when absolute, otherwise resolve it against *base_dir*."""
    candidate = Path(path_like)
    if not candidate.is_absolute():
        candidate = (base_dir / candidate).resolve()
    return candidate
def resolve_reference_paths(ref_arg: str | Path) -> List[str]:
    """Expand *ref_arg* (a training-config JSON or a data path/glob) into data file paths.

    Raises SystemExit when a config JSON carries neither data_glob nor data_path.
    """
    ref = Path(ref_arg)
    if ref.suffix != ".json" or not ref.exists():
        return expand_glob_or_file(ref)
    cfg = load_json(ref)
    pattern = cfg.get("data_glob") or cfg.get("data_path") or ""
    if not pattern:
        raise SystemExit("reference config has no data_glob/data_path")
    return expand_glob_or_file(ref.parent / pattern)
def infer_test_paths(ref_arg: str | Path) -> List[str]:
    """Infer test-split file paths from a training reference (config JSON or path/glob).

    Resolution order:
      1. If *ref_arg* is an existing JSON config, use its ``test_glob``/``test_path``
         when present, otherwise recurse on its ``data_glob``/``data_path``.
      2. Otherwise derive candidates from the training filename by swapping
         ``train``/``TRAIN`` for ``test``/``TEST`` and trying ``test*`` globs in
         the same directory.

    Raises SystemExit when no candidate matches an existing file.
    """
    ref_path = Path(ref_arg)
    if ref_path.suffix == ".json" and ref_path.exists():
        cfg = load_json(ref_path)
        cfg_base = ref_path.parent
        explicit = cfg.get("test_glob") or cfg.get("test_path") or ""
        if explicit:
            return expand_glob_or_file(cfg_base / explicit)
        train_ref = cfg.get("data_glob") or cfg.get("data_path") or ""
        if not train_ref:
            raise SystemExit("reference config has no data_glob/data_path")
        return infer_test_paths(cfg_base / train_ref)
    path = Path(ref_arg)
    # The original branched on whether the reference contained glob characters,
    # but both arms were byte-identical dead duplication, so the candidate list
    # is now built unconditionally.
    parent = path.parent if path.parent != Path("") else Path(".")
    name = path.name
    candidates: List[Path] = []
    if "train" in name:
        candidates.append(parent / name.replace("train", "test"))
        candidates.append(parent / name.replace("TRAIN", "TEST"))
        # NOTE(review): the test* glob fallbacks follow the original's placement
        # under the train-name check — confirm they shouldn't apply to all names.
        candidates.append(parent / "test*.csv.gz")
        candidates.append(parent / "test*.csv")
    for candidate in candidates:
        matches = expand_glob_or_file(candidate)
        if matches:
            return matches
    raise SystemExit(f"could not infer test files from reference: {ref_arg}")
def expand_glob_or_file(path_like: str | Path) -> List[str]:
    """Expand a glob pattern or single file path into a sorted list of absolute paths.

    Non-glob paths that do not exist yield an empty list.
    """
    path = Path(path_like)
    has_glob_chars = any(ch in str(path) for ch in "*?[")
    if has_glob_chars:
        base = path.parent if path.parent != Path("") else Path(".")
        return [str(match.resolve()) for match in sorted(base.glob(path.name))]
    return [str(path.resolve())] if path.exists() else []
def load_split_columns(split_path: str | Path) -> Tuple[str, List[str], List[str], List[str]]:
    """Load a column-split JSON and partition its columns.

    Returns (time_column, continuous_cols, discrete_cols, label_cols): label
    columns are the discrete columns whose names start with "attack"
    (case-insensitive); the time column is excluded from every group.
    """
    split = load_json(split_path)
    time_col = split.get("time_column", "time")
    cont_cols = [name for name in split["continuous"] if name != time_col]
    non_time_discrete = [name for name in split["discrete"] if name != time_col]
    label_cols = [name for name in non_time_discrete if name.lower().startswith("attack")]
    disc_cols = [name for name in non_time_discrete if name not in label_cols]
    return time_col, cont_cols, disc_cols, label_cols
def load_vocab(vocab_path: str | Path, disc_cols: Sequence[str]) -> Tuple[Dict[str, Dict[str, int]], List[int]]:
    """Load the discrete-value vocabulary plus the vocabulary size for each column."""
    vocab = load_json(vocab_path)["vocab"]
    sizes = [len(vocab[name]) for name in disc_cols]
    return vocab, sizes
def load_stats_vectors(stats_path: str | Path, cont_cols: Sequence[str]) -> Tuple[np.ndarray, np.ndarray]:
    """Load per-column mean/std float32 vectors.

    Prefers the raw_mean/raw_std keys and floors each std at 1e-6 to avoid
    division by zero during standardization.
    """
    stats = load_json(stats_path)
    mean_by_col = stats.get("raw_mean", stats["mean"])
    std_by_col = stats.get("raw_std", stats["std"])
    mean_vec = np.asarray([float(mean_by_col[name]) for name in cont_cols], dtype=np.float32)
    std_vec = np.asarray([max(float(std_by_col[name]), 1e-6) for name in cont_cols], dtype=np.float32)
    return mean_vec, std_vec
def load_rows(
    path: str | Path,
    cont_cols: Sequence[str],
    disc_cols: Sequence[str],
    label_cols: Optional[Sequence[str]] = None,
    vocab: Optional[Dict[str, Dict[str, int]]] = None,
    max_rows: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
    """Read a (possibly gzipped) CSV into aligned per-row arrays.

    Returns (cont, disc, labels): float32 continuous values, int64 discrete
    codes, and an int64 per-row label (or None when *label_cols* is empty).
    Discrete values are vocabulary-encoded when *vocab* is given, otherwise
    parsed numerically; the row label is the max over *label_cols*.
    Reading stops after *max_rows* rows when given.
    """
    cont_rows: List[List[float]] = []
    disc_rows: List[List[int]] = []
    label_rows: List[int] = []
    label_cols = list(label_cols or [])
    with open_csv(path) as f:
        reader = csv.DictReader(f)
        for idx, row in enumerate(reader):
            cont_rows.append([float(row[c]) for c in cont_cols])
            if disc_cols:
                if vocab is None:
                    # No vocabulary: treat cells as numeric codes ("3.0" -> 3).
                    disc_rows.append([int(float(row[c])) for c in disc_cols])
                else:
                    # Vocabulary lookup; unseen values fall back to the <UNK>
                    # entry, or code 0 when <UNK> itself is absent.
                    encoded = []
                    for c in disc_cols:
                        mapping = vocab[c]
                        encoded.append(mapping.get(row.get(c, ""), mapping.get("<UNK>", 0)))
                    disc_rows.append(encoded)
            if label_cols:
                # Row label = max over label columns; unparsable or missing
                # cells are skipped (best-effort, never raises).
                label = 0
                for c in label_cols:
                    try:
                        label = max(label, int(float(row.get(c, 0) or 0)))
                    except Exception:
                        continue
                label_rows.append(label)
            if max_rows is not None and idx + 1 >= max_rows:
                break
    # NOTE(review): when disc_cols is empty the disc array is (0, 0) even if
    # cont has rows — downstream code appears to check disc.size; confirm.
    cont = np.asarray(cont_rows, dtype=np.float32) if cont_rows else np.zeros((0, len(cont_cols)), dtype=np.float32)
    disc = np.asarray(disc_rows, dtype=np.int64) if disc_rows else np.zeros((0, len(disc_cols)), dtype=np.int64)
    labels = None
    if label_cols:
        labels = np.asarray(label_rows, dtype=np.int64)
    return cont, disc, labels
def window_array(
    cont: np.ndarray,
    disc: np.ndarray,
    labels: Optional[np.ndarray],
    seq_len: int,
    stride: Optional[int] = None,
    max_windows: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
    """Slice aligned row arrays into fixed-length windows.

    A window's label is the max label inside it. A missing/non-positive
    *stride* means non-overlapping windows (stride == seq_len). When fewer
    rows than *seq_len* exist, empty arrays of the right rank are returned.
    """
    step = seq_len if stride is None or stride <= 0 else stride
    n_rows = cont.shape[0]
    if n_rows < seq_len:
        empty_labels = np.zeros((0,), dtype=np.int64) if labels is not None else None
        return (
            np.zeros((0, seq_len, cont.shape[1]), dtype=np.float32),
            np.zeros((0, seq_len, disc.shape[1]), dtype=np.int64),
            empty_labels,
        )
    starts = list(range(0, n_rows - seq_len + 1, step))
    if max_windows is not None:
        # Matches the original loop, whose cap check ran after appending:
        # even max_windows <= 0 still yields one window.
        starts = starts[: max(max_windows, 1)]
    cont_out = np.asarray([cont[s:s + seq_len] for s in starts], dtype=np.float32)
    disc_out = np.asarray([disc[s:s + seq_len] for s in starts], dtype=np.int64)
    if labels is None:
        return cont_out, disc_out, None
    window_labels = [int(labels[s:s + seq_len].max()) for s in starts]
    return cont_out, disc_out, np.asarray(window_labels, dtype=np.int64)
def load_windows_from_paths(
    paths: Sequence[str],
    cont_cols: Sequence[str],
    disc_cols: Sequence[str],
    seq_len: int,
    vocab: Optional[Dict[str, Dict[str, int]]] = None,
    label_cols: Optional[Sequence[str]] = None,
    stride: Optional[int] = None,
    max_windows: Optional[int] = None,
    max_rows_per_file: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
    """Load rows from every file in *paths* and window them into stacked arrays.

    Delegates per-file parsing to load_rows and slicing to window_array, then
    concatenates across files. *max_windows* caps the total window count
    across all files; *max_rows_per_file* caps rows read per file. Returns
    (cont_windows, disc_windows, labels); labels is None unless *label_cols*
    is given. Windowing never crosses file boundaries.
    """
    cont_all: List[np.ndarray] = []
    disc_all: List[np.ndarray] = []
    label_all: List[np.ndarray] = []
    total = 0
    for path in paths:
        # Budget left for this file; None means unlimited.
        remaining = None if max_windows is None else max(0, max_windows - total)
        if remaining == 0:
            break
        cont, disc, labels = load_rows(
            path,
            cont_cols,
            disc_cols,
            label_cols=label_cols,
            vocab=vocab,
            max_rows=max_rows_per_file,
        )
        w_cont, w_disc, w_labels = window_array(
            cont,
            disc,
            labels,
            seq_len=seq_len,
            stride=stride,
            max_windows=remaining,
        )
        if w_cont.size == 0:
            continue
        cont_all.append(w_cont)
        disc_all.append(w_disc)
        if w_labels is not None:
            label_all.append(w_labels)
        total += w_cont.shape[0]
    if not cont_all:
        # No file produced a full window: return correctly-shaped empties.
        empty_cont = np.zeros((0, seq_len, len(cont_cols)), dtype=np.float32)
        empty_disc = np.zeros((0, seq_len, len(disc_cols)), dtype=np.int64)
        empty_labels = np.zeros((0,), dtype=np.int64) if label_cols else None
        return empty_cont, empty_disc, empty_labels
    cont_out = np.concatenate(cont_all, axis=0)
    disc_out = np.concatenate(disc_all, axis=0)
    label_out = np.concatenate(label_all, axis=0) if label_all else None
    return cont_out, disc_out, label_out
def filter_windows_by_label(
    cont_windows: np.ndarray,
    disc_windows: np.ndarray,
    labels: Optional[np.ndarray],
    target_label: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Keep only the windows whose label equals *target_label*.

    Raises ValueError when *labels* is None.
    """
    if labels is None:
        raise ValueError("labels are required for label filtering")
    keep = labels == int(target_label)
    return cont_windows[keep], disc_windows[keep], labels[keep]
def standardize_cont_windows(
    cont_windows: np.ndarray,
    mean_vec: np.ndarray,
    std_vec: np.ndarray,
) -> np.ndarray:
    """Z-score continuous windows per feature via broadcast mean/std vectors."""
    mean = mean_vec.reshape(1, 1, -1)
    std = std_vec.reshape(1, 1, -1)
    return (cont_windows - mean) / std
def build_flat_window_vectors(
    cont_windows: np.ndarray,
    disc_windows: np.ndarray,
    mean_vec: np.ndarray,
    std_vec: np.ndarray,
    vocab_sizes: Sequence[int],
) -> np.ndarray:
    """Flatten each window into one float32 vector: z-scored continuous values
    concatenated with discrete codes scaled into [0, 1]."""
    n_windows = cont_windows.shape[0]
    cont_flat = standardize_cont_windows(cont_windows, mean_vec, std_vec).reshape(n_windows, -1)
    if disc_windows.size == 0:
        return cont_flat.astype(np.float32)
    # Divide each code by (vocab_size - 1), guarding size-1 vocabularies.
    scale = np.asarray([max(size - 1, 1) for size in vocab_sizes], dtype=np.float32).reshape(1, 1, -1)
    disc_flat = (disc_windows.astype(np.float32) / scale).reshape(disc_windows.shape[0], -1)
    return np.concatenate([cont_flat, disc_flat], axis=1).astype(np.float32)
def build_histogram_embeddings(
    cont_windows: np.ndarray,
    disc_windows: np.ndarray,
    mean_vec: np.ndarray,
    std_vec: np.ndarray,
    vocab_sizes: Sequence[int],
) -> np.ndarray:
    """Embed each window as z-scored continuous values plus, per discrete
    column, a histogram of value frequencies over the window's timesteps."""
    n_windows = cont_windows.shape[0]
    cont_flat = standardize_cont_windows(cont_windows, mean_vec, std_vec).reshape(n_windows, -1)
    if disc_windows.size == 0:
        return cont_flat.astype(np.float32)
    histograms: List[np.ndarray] = []
    for col_idx, size in enumerate(vocab_sizes):
        column = disc_windows[:, :, col_idx]
        hist = np.zeros((disc_windows.shape[0], size), dtype=np.float32)
        for code in range(size):
            # Fraction of timesteps in each window equal to this code.
            hist[:, code] = (column == code).mean(axis=1)
        histograms.append(hist)
    if histograms:
        disc_part = np.concatenate(histograms, axis=1)
    else:
        disc_part = np.zeros((n_windows, 0), dtype=np.float32)
    return np.concatenate([cont_flat, disc_part], axis=1).astype(np.float32)
def sample_indices(n_items: int, max_items: Optional[int], seed: int) -> np.ndarray:
    """Return every index, or a sorted seeded random subset of *max_items* of them."""
    if max_items is None or n_items <= max_items:
        return np.arange(n_items, dtype=np.int64)
    chosen = np.random.default_rng(seed).choice(n_items, size=max_items, replace=False)
    return np.sort(chosen)
def subset_by_indices(array: np.ndarray, indices: np.ndarray) -> np.ndarray:
    """Index *array* with *indices*; None passes through unchanged."""
    return array if array is None else array[indices]
def compute_corr_matrix(rows: np.ndarray) -> np.ndarray:
    """Feature-by-feature Pearson correlation of *rows*; NaN/inf entries become
    zero and an all-zero matrix is returned when fewer than two rows exist."""
    n_features = rows.shape[1]
    if rows.shape[0] < 2:
        return np.zeros((n_features, n_features), dtype=np.float32)
    corr = np.nan_to_num(np.corrcoef(rows, rowvar=False), nan=0.0, posinf=0.0, neginf=0.0)
    return corr.astype(np.float32)
def compute_lagged_corr_matrix(rows: np.ndarray, lag: int = 1) -> np.ndarray:
    """Cross-correlation between features at time t and time t+lag.

    Denominators are floored at tiny epsilons; returns zeros when there are
    not enough rows to form a lagged pair.
    """
    n_features = rows.shape[1]
    if rows.shape[0] <= lag:
        return np.zeros((n_features, n_features), dtype=np.float32)
    past = rows[:-lag]
    future = rows[lag:]
    past = past - past.mean(axis=0, keepdims=True)
    future = future - future.mean(axis=0, keepdims=True)
    dof_past = max(past.shape[0] - 1, 1)
    dof_future = max(future.shape[0] - 1, 1)
    cov = past.T @ future / dof_past
    std_past = np.sqrt(np.maximum((past ** 2).sum(axis=0) / dof_past, 1e-8))
    std_future = np.sqrt(np.maximum((future ** 2).sum(axis=0) / dof_future, 1e-8))
    corr = cov / np.maximum(np.outer(std_past, std_future), 1e-8)
    return np.nan_to_num(corr, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
def split_process_groups(feature_names: Sequence[str]) -> Dict[str, List[int]]:
    """Group feature indices by the prefix before the first underscore."""
    groups: Dict[str, List[int]] = {}
    for idx, name in enumerate(feature_names):
        prefix, _, _ = name.partition("_")
        bucket = groups.setdefault(prefix, [])
        bucket.append(idx)
    return groups
def compute_average_psd(cont_windows: np.ndarray) -> np.ndarray:
    """Average normalized power spectral density per feature over all windows.

    Windows are mean-detrended along time before the rFFT. Returns an array
    of shape (n_features, n_freqs) where each row sums to 1; rows with zero
    total power stay all-zero.
    """
    if cont_windows.size == 0:
        return np.zeros((0, 0), dtype=np.float32)
    detrended = cont_windows - cont_windows.mean(axis=1, keepdims=True)
    power = (np.abs(np.fft.rfft(detrended, axis=1)) ** 2).mean(axis=0).T.astype(np.float32)
    totals = power.sum(axis=1, keepdims=True)
    totals[totals <= 0] = 1.0
    return power / totals
def psd_distance_stats(real_psd: np.ndarray, gen_psd: np.ndarray) -> Dict[str, float]:
    """Distances between real and generated normalized PSDs, averaged over features.

    Reports mean L1 distance, mean cosine distance, and the mean absolute
    difference of low/high-frequency power ratios (split at a quarter of the
    frequency bins). Every value is NaN when either input is empty.
    """
    if real_psd.size == 0 or gen_psd.size == 0:
        nan = float("nan")
        return {"avg_psd_l1": nan, "avg_psd_cosine": nan, "avg_low_high_ratio_abs_diff": nan}
    l1_per_feature = np.abs(real_psd - gen_psd).mean(axis=1)
    split = max(1, real_psd.shape[1] // 4)
    cosine_dists: List[float] = []
    ratio_gaps: List[float] = []
    for feature_idx in range(real_psd.shape[0]):
        real_vec = real_psd[feature_idx]
        gen_vec = gen_psd[feature_idx]
        norm_product = max(np.linalg.norm(real_vec) * np.linalg.norm(gen_vec), 1e-8)
        cosine_dists.append(1.0 - float(np.dot(real_vec, gen_vec) / norm_product))
        real_ratio = float(real_vec[:split].sum()) / max(float(real_vec[split:].sum()), 1e-8)
        gen_ratio = float(gen_vec[:split].sum()) / max(float(gen_vec[split:].sum()), 1e-8)
        ratio_gaps.append(abs(real_ratio - gen_ratio))
    return {
        "avg_psd_l1": float(np.mean(l1_per_feature)),
        "avg_psd_cosine": float(np.mean(cosine_dists)),
        "avg_low_high_ratio_abs_diff": float(np.mean(ratio_gaps)),
    }
def pairwise_sq_dists(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """All-pairs squared Euclidean distances, clamped at zero against rounding error."""
    sq_x = (x ** 2).sum(axis=1, keepdims=True)
    sq_y = (y ** 2).sum(axis=1, keepdims=True).T
    cross = x @ y.T
    return np.maximum(sq_x + sq_y - 2.0 * cross, 0.0)
def median_heuristic_gamma(x: np.ndarray, y: np.ndarray) -> float:
    """RBF bandwidth via the median heuristic.

    Uses at most the first 256 pooled samples; falls back to 1.0 when fewer
    than two samples exist or all pairwise distances are zero.
    """
    pooled = np.concatenate([x, y], axis=0)
    if pooled.shape[0] <= 1:
        return 1.0
    subset = pooled[: min(pooled.shape[0], 256)]
    d2 = pairwise_sq_dists(subset, subset)
    off_diag = d2[np.triu_indices_from(d2, k=1)]
    positive = off_diag[off_diag > 0]
    if positive.size == 0:
        return 1.0
    return 1.0 / max(2.0 * float(np.median(positive)), 1e-8)
def rbf_mmd(x: np.ndarray, y: np.ndarray, gamma: Optional[float] = None) -> Tuple[float, float]:
    """Squared maximum mean discrepancy between samples x and y under an RBF kernel.

    Returns (mmd2, gamma). Within-sample terms use the unbiased estimator
    (diagonal removed); the cross term uses the plain mean. Gamma defaults to
    the median heuristic. Empty input yields (nan, 1.0).
    """
    if x.shape[0] == 0 or y.shape[0] == 0:
        return float("nan"), 1.0
    if gamma is None:
        gamma = median_heuristic_gamma(x, y)
    k_xx = np.exp(-gamma * pairwise_sq_dists(x, x))
    k_yy = np.exp(-gamma * pairwise_sq_dists(y, y))
    k_xy = np.exp(-gamma * pairwise_sq_dists(x, y))
    m = max(x.shape[0], 1)
    n = max(y.shape[0], 1)
    # Unbiased within-sample means exclude the self-similarity diagonal;
    # degenerate single-sample inputs contribute zero.
    if m > 1:
        term_xx = (k_xx.sum() - np.trace(k_xx)) / (m * (m - 1))
    else:
        term_xx = 0.0
    if n > 1:
        term_yy = (k_yy.sum() - np.trace(k_yy)) / (n * (n - 1))
    else:
        term_yy = 0.0
    term_xy = 2.0 * k_xy.mean()
    return float(term_xx + term_yy - term_xy), float(gamma)
def duplicate_rate(vectors: np.ndarray, decimals: int = 5) -> float:
    """Fraction of rows that duplicate another row after rounding to *decimals*."""
    total = vectors.shape[0]
    if total <= 1:
        return 0.0
    n_unique = np.unique(np.round(vectors, decimals=decimals), axis=0).shape[0]
    return float(1.0 - n_unique / total)
def exact_match_rate(query: np.ndarray, base: np.ndarray, decimals: int = 5) -> float:
    """Fraction of query rows that exactly equal some base row after rounding."""
    if query.shape[0] == 0 or base.shape[0] == 0:
        return 0.0
    base_set = {tuple(row.tolist()) for row in np.round(base, decimals=decimals)}
    hits = 0
    for row in np.round(query, decimals=decimals):
        if tuple(row.tolist()) in base_set:
            hits += 1
    return float(hits / query.shape[0])
def nearest_neighbor_distance_stats(query: np.ndarray, base: np.ndarray, batch_size: int = 128) -> Dict[str, float]:
    """Mean/median/min Euclidean distance from each query row to its nearest base row.

    Queries are processed in chunks of *batch_size* to bound peak memory;
    all statistics are NaN when either side is empty.
    """
    if query.shape[0] == 0 or base.shape[0] == 0:
        return {"mean": float("nan"), "median": float("nan"), "min": float("nan")}
    per_chunk: List[np.ndarray] = []
    for offset in range(0, query.shape[0], batch_size):
        chunk = query[offset:offset + batch_size]
        nearest_sq = np.min(pairwise_sq_dists(chunk, base), axis=1)
        per_chunk.append(np.sqrt(nearest_sq))
    nn_dists = np.concatenate(per_chunk, axis=0)
    return {
        "mean": float(nn_dists.mean()),
        "median": float(np.median(nn_dists)),
        "min": float(nn_dists.min()),
    }
def one_nn_two_sample_accuracy(real_vecs: np.ndarray, gen_vecs: np.ndarray) -> float:
    """Leave-one-out 1-NN two-sample test accuracy between real and generated vectors.

    Accuracy near 0.5 means the two samples are hard to tell apart; near 1.0
    means they are clearly separable. Returns NaN when either sample has
    fewer than two vectors.
    """
    if real_vecs.shape[0] < 2 or gen_vecs.shape[0] < 2:
        return float("nan")
    x = np.concatenate([real_vecs, gen_vecs], axis=0)
    # 0 = real, 1 = generated.
    y = np.concatenate(
        [
            np.zeros(real_vecs.shape[0], dtype=np.int64),
            np.ones(gen_vecs.shape[0], dtype=np.int64),
        ],
        axis=0,
    )
    d2 = pairwise_sq_dists(x, x)
    # Exclude self-matches so each point's neighbor is a different sample.
    np.fill_diagonal(d2, np.inf)
    nn = np.argmin(d2, axis=1)
    pred = y[nn]
    return float((pred == y).mean())
def binary_classification_curves(y_true: np.ndarray, scores: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Threshold-sweep curves for a binary problem.

    Sorts samples by descending score and returns, aligned per sorted sample:
    (sorted_scores, fpr, tpr, precision). Positive/negative counts are floored
    at 1 so rates stay defined for single-class input, and tied scores are not
    deduplicated (one curve point per sample).
    """
    order = np.argsort(-scores)
    y = y_true[order].astype(np.int64)
    scores_sorted = scores[order]
    tp = np.cumsum(y == 1)
    fp = np.cumsum(y == 0)
    positives = max(int((y_true == 1).sum()), 1)
    negatives = max(int((y_true == 0).sum()), 1)
    tpr = tp / positives
    fpr = fp / negatives
    precision = tp / np.maximum(tp + fp, 1)
    # (The original also bound an unused `recall = tpr`; removed.)
    return scores_sorted, fpr, tpr, precision
def binary_auroc(y_true: np.ndarray, scores: np.ndarray) -> float:
    """Area under the ROC curve; NaN when only one class is present.

    The trapezoid integral is computed explicitly because ``np.trapz`` was
    removed in NumPy 2.0 (renamed ``np.trapezoid``); the result is identical.
    """
    if len(np.unique(y_true)) < 2:
        return float("nan")
    _, fpr, tpr, _ = binary_classification_curves(y_true, scores)
    fpr = np.concatenate([[0.0], fpr, [1.0]])
    tpr = np.concatenate([[0.0], tpr, [1.0]])
    # Trapezoidal rule: interval width times average height, summed.
    return float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1])) / 2.0)
def binary_average_precision(y_true: np.ndarray, scores: np.ndarray) -> float:
    """Average precision (step-wise integral of the PR curve); NaN for one-class input.

    The original implementation first called ``binary_classification_curves``
    and then discarded or recomputed every one of its results; that dead call
    has been removed — the output is unchanged.
    """
    if len(np.unique(y_true)) < 2:
        return float("nan")
    positives = max(int((y_true == 1).sum()), 1)
    order = np.argsort(-scores)
    y = y_true[order].astype(np.int64)
    tp = np.cumsum(y == 1)
    recall = tp / positives
    precision = tp / np.arange(1, len(tp) + 1)
    # Prepend a (recall=0, precision=first-point) anchor for the step integral.
    recall = np.concatenate([[0.0], recall])
    precision = np.concatenate([[precision[0] if precision.size else 1.0], precision])
    return float(np.sum((recall[1:] - recall[:-1]) * precision[1:]))
def binary_f1_at_threshold(y_true: np.ndarray, scores: np.ndarray, threshold: float) -> Dict[str, float]:
    """Precision/recall/F1 for predictions ``scores >= threshold``.

    Denominators are floored at 1, so degenerate inputs yield zeros rather
    than dividing by zero.
    """
    predicted = (scores >= threshold).astype(np.int64)
    true_pos = int(((predicted == 1) & (y_true == 1)).sum())
    false_pos = int(((predicted == 1) & (y_true == 0)).sum())
    false_neg = int(((predicted == 0) & (y_true == 1)).sum())
    precision = true_pos / max(true_pos + false_pos, 1)
    recall = true_pos / max(true_pos + false_neg, 1)
    f1 = 0.0 if precision + recall == 0 else 2.0 * precision * recall / (precision + recall)
    return {"threshold": float(threshold), "precision": float(precision), "recall": float(recall), "f1": float(f1)}
def best_binary_f1(y_true: np.ndarray, scores: np.ndarray) -> Dict[str, float]:
    """Sweep every distinct score as a threshold and return the stats of the best F1."""
    if y_true.size == 0:
        nan = float("nan")
        return {"threshold": nan, "precision": nan, "recall": nan, "f1": nan}
    candidates = np.unique(scores)
    best = {"threshold": float(candidates[0]), "precision": 0.0, "recall": 0.0, "f1": -1.0}
    for candidate in candidates:
        current = binary_f1_at_threshold(y_true, scores, float(candidate))
        if current["f1"] > best["f1"]:
            best = current
    return best
def set_random_seed(seed: int) -> None:
    """Seed both the stdlib and NumPy global RNGs for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)