Add comprehensive evaluation and ablation runner
This commit is contained in:
568
example/eval_utils.py
Normal file
568
example/eval_utils.py
Normal file
@@ -0,0 +1,568 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Shared utilities for evaluation, downstream utility, and ablations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import gzip
|
||||
import json
|
||||
import math
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def load_json(path: str | Path) -> Dict:
    """Read a UTF-8 encoded JSON file and return the parsed object."""
    return json.loads(Path(path).read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
def open_csv(path: str | Path):
    """Open *path* for text reading with newline='' for the csv module.

    Files ending in ``.gz`` are transparently decompressed via gzip.
    """
    name = str(path)
    if not name.endswith(".gz"):
        return open(name, "r", newline="")
    return gzip.open(name, "rt", newline="")
|
||||
|
||||
|
||||
def resolve_path(base_dir: Path, path_like: str | Path) -> Path:
    """Return *path_like* unchanged if absolute, else resolved against *base_dir*."""
    candidate = Path(path_like)
    return candidate if candidate.is_absolute() else (base_dir / candidate).resolve()
|
||||
|
||||
|
||||
def resolve_reference_paths(ref_arg: str | Path) -> List[str]:
    """Resolve a reference-data argument into a list of file paths.

    *ref_arg* may be a JSON config (providing ``data_glob``/``data_path``
    relative to the config's directory) or a direct path/glob pattern.

    Raises SystemExit when a JSON config lacks both keys.
    """
    ref = Path(ref_arg)
    if ref.suffix != ".json" or not ref.exists():
        return expand_glob_or_file(ref)
    cfg = load_json(ref)
    pattern = cfg.get("data_glob") or cfg.get("data_path") or ""
    if not pattern:
        raise SystemExit("reference config has no data_glob/data_path")
    return expand_glob_or_file(ref.parent / pattern)
|
||||
|
||||
|
||||
def infer_test_paths(ref_arg: str | Path) -> List[str]:
    """Infer the test-split files that correspond to a training reference.

    *ref_arg* may be a JSON config (honoring ``test_glob``/``test_path`` when
    present, otherwise recursing on ``data_glob``/``data_path``), or a direct
    path/glob pattern whose name is rewritten train -> test.

    Raises SystemExit when no candidate test file matches.
    """
    ref_path = Path(ref_arg)
    if ref_path.suffix == ".json" and ref_path.exists():
        cfg = load_json(ref_path)
        cfg_base = ref_path.parent
        explicit = cfg.get("test_glob") or cfg.get("test_path") or ""
        if explicit:
            return expand_glob_or_file(cfg_base / explicit)
        train_ref = cfg.get("data_glob") or cfg.get("data_path") or ""
        if not train_ref:
            raise SystemExit("reference config has no data_glob/data_path")
        return infer_test_paths(cfg_base / train_ref)

    # The original code had two byte-identical branches for glob and non-glob
    # inputs; the duplication is collapsed here (behavior unchanged).
    path = Path(ref_arg)
    parent = path.parent if path.parent != Path("") else Path(".")
    name = path.name
    candidates: List[Path] = []
    if "train" in name:
        candidates.append(parent / name.replace("train", "test"))
    candidates.append(parent / name.replace("TRAIN", "TEST"))
    candidates.append(parent / "test*.csv.gz")
    candidates.append(parent / "test*.csv")

    for candidate in candidates:
        matches = expand_glob_or_file(candidate)
        if matches:
            return matches
    raise SystemExit(f"could not infer test files from reference: {ref_arg}")
|
||||
|
||||
|
||||
def expand_glob_or_file(path_like: str | Path) -> List[str]:
    """Expand a glob pattern (or wrap an existing file) into resolved paths.

    Non-glob paths that do not exist yield an empty list; glob matches are
    returned sorted.
    """
    path = Path(path_like)
    has_magic = any(marker in str(path) for marker in "*?[")
    if has_magic:
        base = path.parent if path.parent != Path("") else Path(".")
        return [str(match.resolve()) for match in sorted(base.glob(path.name))]
    return [str(path.resolve())] if path.exists() else []
|
||||
|
||||
|
||||
def load_split_columns(split_path: str | Path) -> Tuple[str, List[str], List[str], List[str]]:
    """Read a column-split JSON and partition its columns.

    Returns (time_column, continuous, non-label discrete, label) names; the
    time column is excluded everywhere, and label columns are the discrete
    columns whose name starts with 'attack' (case-insensitive).
    """
    split = load_json(split_path)
    time_col = split.get("time_column", "time")
    cont_cols = [name for name in split["continuous"] if name != time_col]
    discrete = [name for name in split["discrete"] if name != time_col]
    label_cols = [name for name in discrete if name.lower().startswith("attack")]
    disc_cols = [name for name in discrete if name not in label_cols]
    return time_col, cont_cols, disc_cols, label_cols
|
||||
|
||||
|
||||
def load_vocab(vocab_path: str | Path, disc_cols: Sequence[str]) -> Tuple[Dict[str, Dict[str, int]], List[int]]:
    """Load the token vocabulary plus per-column vocabulary sizes for *disc_cols*."""
    vocab = load_json(vocab_path)["vocab"]
    sizes = [len(vocab[column]) for column in disc_cols]
    return vocab, sizes
|
||||
|
||||
|
||||
def load_stats_vectors(stats_path: str | Path, cont_cols: Sequence[str]) -> Tuple[np.ndarray, np.ndarray]:
    """Load per-column mean/std float32 vectors for *cont_cols*.

    Prefers the 'raw_mean'/'raw_std' keys, falling back to 'mean'/'std'.
    Standard deviations are floored at 1e-6 to keep later divisions safe.
    """
    stats = load_json(stats_path)
    mean_map = stats.get("raw_mean", stats["mean"])
    std_map = stats.get("raw_std", stats["std"])
    mean_vec = np.array([float(mean_map[c]) for c in cont_cols], dtype=np.float32)
    std_vec = np.array([max(float(std_map[c]), 1e-6) for c in cont_cols], dtype=np.float32)
    return mean_vec, std_vec
|
||||
|
||||
|
||||
def load_rows(
    path: str | Path,
    cont_cols: Sequence[str],
    disc_cols: Sequence[str],
    label_cols: Optional[Sequence[str]] = None,
    vocab: Optional[Dict[str, Dict[str, int]]] = None,
    max_rows: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
    """Read a CSV into (continuous, discrete, labels) arrays.

    Discrete values are vocab-encoded when *vocab* is supplied (unknown tokens
    fall back to the '<UNK>' entry, else code 0), otherwise parsed numerically.
    The per-row label is the max over *label_cols*; unparsable label cells are
    skipped. Reading stops after *max_rows* rows when given. Labels are None
    unless *label_cols* is non-empty.
    """
    label_cols = list(label_cols or [])
    cont_rows: List[List[float]] = []
    disc_rows: List[List[int]] = []
    label_rows: List[int] = []

    with open_csv(path) as handle:
        for idx, row in enumerate(csv.DictReader(handle)):
            cont_rows.append([float(row[c]) for c in cont_cols])
            if disc_cols:
                if vocab is None:
                    disc_rows.append([int(float(row[c])) for c in disc_cols])
                else:
                    disc_rows.append([
                        vocab[c].get(row.get(c, ""), vocab[c].get("<UNK>", 0))
                        for c in disc_cols
                    ])
            if label_cols:
                label = 0
                for c in label_cols:
                    try:
                        label = max(label, int(float(row.get(c, 0) or 0)))
                    except Exception:
                        continue  # best-effort: ignore malformed label cells
                label_rows.append(label)
            if max_rows is not None and idx + 1 >= max_rows:
                break

    cont = np.asarray(cont_rows, dtype=np.float32) if cont_rows else np.zeros((0, len(cont_cols)), dtype=np.float32)
    disc = np.asarray(disc_rows, dtype=np.int64) if disc_rows else np.zeros((0, len(disc_cols)), dtype=np.int64)
    labels = np.asarray(label_rows, dtype=np.int64) if label_cols else None
    return cont, disc, labels
|
||||
|
||||
|
||||
def window_array(
    cont: np.ndarray,
    disc: np.ndarray,
    labels: Optional[np.ndarray],
    seq_len: int,
    stride: Optional[int] = None,
    max_windows: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
    """Slice row-aligned arrays into fixed-length windows.

    When *stride* is None or non-positive, windows do not overlap
    (stride = seq_len). Each window's label is the max label it covers.
    Inputs shorter than *seq_len* yield correctly-shaped empty outputs.
    """
    step = seq_len if stride is None or stride <= 0 else stride
    n_rows = cont.shape[0]
    if n_rows < seq_len:
        empty_labels = np.zeros((0,), dtype=np.int64) if labels is not None else None
        return (
            np.zeros((0, seq_len, cont.shape[1]), dtype=np.float32),
            np.zeros((0, seq_len, disc.shape[1]), dtype=np.int64),
            empty_labels,
        )

    cont_parts: List[np.ndarray] = []
    disc_parts: List[np.ndarray] = []
    window_labels: List[int] = []
    taken = 0
    for begin in range(0, n_rows - seq_len + 1, step):
        stop = begin + seq_len
        cont_parts.append(cont[begin:stop])
        disc_parts.append(disc[begin:stop])
        if labels is not None:
            window_labels.append(int(labels[begin:stop].max()))
        taken += 1
        if max_windows is not None and taken >= max_windows:
            break

    label_out = np.asarray(window_labels, dtype=np.int64) if labels is not None else None
    return (
        np.asarray(cont_parts, dtype=np.float32),
        np.asarray(disc_parts, dtype=np.int64),
        label_out,
    )
|
||||
|
||||
|
||||
def load_windows_from_paths(
    paths: Sequence[str],
    cont_cols: Sequence[str],
    disc_cols: Sequence[str],
    seq_len: int,
    vocab: Optional[Dict[str, Dict[str, int]]] = None,
    label_cols: Optional[Sequence[str]] = None,
    stride: Optional[int] = None,
    max_windows: Optional[int] = None,
    max_rows_per_file: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
    """Load every file in *paths*, window it, and stack the results.

    Stops early once *max_windows* windows are collected; files contributing
    no complete window are skipped. When nothing survives, correctly-shaped
    empty arrays are returned (labels only when *label_cols* is given).
    """
    cont_parts: List[np.ndarray] = []
    disc_parts: List[np.ndarray] = []
    label_parts: List[np.ndarray] = []
    n_windows = 0

    for file_path in paths:
        budget = None if max_windows is None else max(0, max_windows - n_windows)
        if budget == 0:
            break
        cont, disc, labels = load_rows(
            file_path,
            cont_cols,
            disc_cols,
            label_cols=label_cols,
            vocab=vocab,
            max_rows=max_rows_per_file,
        )
        w_cont, w_disc, w_labels = window_array(
            cont,
            disc,
            labels,
            seq_len=seq_len,
            stride=stride,
            max_windows=budget,
        )
        if w_cont.size == 0:
            continue
        cont_parts.append(w_cont)
        disc_parts.append(w_disc)
        if w_labels is not None:
            label_parts.append(w_labels)
        n_windows += w_cont.shape[0]

    if not cont_parts:
        return (
            np.zeros((0, seq_len, len(cont_cols)), dtype=np.float32),
            np.zeros((0, seq_len, len(disc_cols)), dtype=np.int64),
            np.zeros((0,), dtype=np.int64) if label_cols else None,
        )

    return (
        np.concatenate(cont_parts, axis=0),
        np.concatenate(disc_parts, axis=0),
        np.concatenate(label_parts, axis=0) if label_parts else None,
    )
|
||||
|
||||
|
||||
def filter_windows_by_label(
    cont_windows: np.ndarray,
    disc_windows: np.ndarray,
    labels: Optional[np.ndarray],
    target_label: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Keep only the windows whose label equals *target_label*.

    Raises ValueError when *labels* is None.
    """
    if labels is None:
        raise ValueError("labels are required for label filtering")
    keep = labels == int(target_label)
    return cont_windows[keep], disc_windows[keep], labels[keep]
|
||||
|
||||
|
||||
def standardize_cont_windows(
    cont_windows: np.ndarray,
    mean_vec: np.ndarray,
    std_vec: np.ndarray,
) -> np.ndarray:
    """Z-score (windows, time, features) data with per-feature mean/std vectors."""
    mean = mean_vec.reshape(1, 1, -1)
    std = std_vec.reshape(1, 1, -1)
    return (cont_windows - mean) / std
|
||||
|
||||
|
||||
def build_flat_window_vectors(
    cont_windows: np.ndarray,
    disc_windows: np.ndarray,
    mean_vec: np.ndarray,
    std_vec: np.ndarray,
    vocab_sizes: Sequence[int],
) -> np.ndarray:
    """Flatten each window into one float32 feature vector.

    Continuous channels are z-scored; discrete codes are divided by
    (vocab_size - 1), floored at 1, so codes land in roughly [0, 1].
    """
    n_windows = cont_windows.shape[0]
    cont_flat = standardize_cont_windows(cont_windows, mean_vec, std_vec).reshape(n_windows, -1)
    if disc_windows.size == 0:
        return cont_flat.astype(np.float32)
    scale = np.asarray([max(size - 1, 1) for size in vocab_sizes], dtype=np.float32)
    disc_flat = (disc_windows.astype(np.float32) / scale.reshape(1, 1, -1)).reshape(disc_windows.shape[0], -1)
    return np.concatenate([cont_flat, disc_flat], axis=1).astype(np.float32)
|
||||
|
||||
|
||||
def build_histogram_embeddings(
    cont_windows: np.ndarray,
    disc_windows: np.ndarray,
    mean_vec: np.ndarray,
    std_vec: np.ndarray,
    vocab_sizes: Sequence[int],
) -> np.ndarray:
    """Embed windows as z-scored continuous values plus discrete histograms.

    For every discrete column, the per-window fraction of each vocabulary code
    is appended after the flattened standardized continuous block.
    """
    n_windows = cont_windows.shape[0]
    cont_flat = standardize_cont_windows(cont_windows, mean_vec, std_vec).reshape(n_windows, -1)
    if disc_windows.size == 0:
        return cont_flat.astype(np.float32)
    histograms: List[np.ndarray] = []
    for column, size in enumerate(vocab_sizes):
        column_values = disc_windows[:, :, column]
        hist = np.zeros((disc_windows.shape[0], size), dtype=np.float32)
        for code in range(size):
            hist[:, code] = (column_values == code).mean(axis=1)
        histograms.append(hist)
    if histograms:
        disc_hist = np.concatenate(histograms, axis=1)
    else:
        disc_hist = np.zeros((n_windows, 0), dtype=np.float32)
    return np.concatenate([cont_flat, disc_hist], axis=1).astype(np.float32)
|
||||
|
||||
|
||||
def sample_indices(n_items: int, max_items: Optional[int], seed: int) -> np.ndarray:
    """Return sorted item indices: all of them, or a seeded sample of *max_items*."""
    if max_items is None or n_items <= max_items:
        return np.arange(n_items, dtype=np.int64)
    generator = np.random.default_rng(seed)
    chosen = generator.choice(n_items, size=max_items, replace=False)
    return np.sort(chosen)
|
||||
|
||||
|
||||
def subset_by_indices(array: np.ndarray, indices: np.ndarray) -> np.ndarray:
    """Index *array* by *indices*, passing a None array through unchanged."""
    return array if array is None else array[indices]
|
||||
|
||||
|
||||
def compute_corr_matrix(rows: np.ndarray) -> np.ndarray:
    """Feature-feature Pearson correlation (float32), NaN/inf zeroed.

    Returns an all-zero matrix when fewer than two rows are available.
    """
    n_features = rows.shape[1]
    if rows.shape[0] < 2:
        return np.zeros((n_features, n_features), dtype=np.float32)
    corr = np.corrcoef(rows, rowvar=False)
    corr = np.nan_to_num(corr, nan=0.0, posinf=0.0, neginf=0.0)
    return corr.astype(np.float32)
|
||||
|
||||
|
||||
def compute_lagged_corr_matrix(rows: np.ndarray, lag: int = 1) -> np.ndarray:
    """Correlation between features at time t and time t+lag (float32).

    Returns zeros when the series is too short; variances are floored at 1e-8
    to avoid division blow-ups, and non-finite results are zeroed.
    """
    n_features = rows.shape[1]
    if rows.shape[0] <= lag:
        return np.zeros((n_features, n_features), dtype=np.float32)
    lead = rows[:-lag] - rows[:-lag].mean(axis=0, keepdims=True)
    lagged = rows[lag:] - rows[lag:].mean(axis=0, keepdims=True)
    dof = max(lead.shape[0] - 1, 1)
    cov = lead.T @ lagged / dof
    std_lead = np.sqrt(np.maximum((lead ** 2).sum(axis=0) / dof, 1e-8))
    std_lag = np.sqrt(np.maximum((lagged ** 2).sum(axis=0) / max(lagged.shape[0] - 1, 1), 1e-8))
    corr = cov / np.maximum(np.outer(std_lead, std_lag), 1e-8)
    return np.nan_to_num(corr, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
|
||||
|
||||
|
||||
def split_process_groups(feature_names: Sequence[str]) -> Dict[str, List[int]]:
    """Group feature indices by the prefix before the first underscore."""
    groups: Dict[str, List[int]] = {}
    for position, feature in enumerate(feature_names):
        key = feature.split("_", 1)[0]
        groups.setdefault(key, []).append(position)
    return groups
|
||||
|
||||
|
||||
def compute_average_psd(cont_windows: np.ndarray) -> np.ndarray:
    """Mean normalized power spectrum per feature across all windows.

    Windows are demeaned along time before the rFFT. Returns a
    (features, n_freq) float32 array whose rows sum to 1; all-zero rows are
    divided by 1 and left at zero. Empty input yields a (0, 0) array.
    """
    if cont_windows.size == 0:
        return np.zeros((0, 0), dtype=np.float32)
    demeaned = cont_windows - cont_windows.mean(axis=1, keepdims=True)
    power = (np.abs(np.fft.rfft(demeaned, axis=1)) ** 2).mean(axis=0).T
    power = power.astype(np.float32)
    totals = power.sum(axis=1, keepdims=True)
    totals[totals <= 0] = 1.0
    return power / totals
|
||||
|
||||
|
||||
def psd_distance_stats(real_psd: np.ndarray, gen_psd: np.ndarray) -> Dict[str, float]:
    """Compare per-feature normalized spectra of real vs generated data.

    Reports the mean L1 distance, mean cosine distance, and mean absolute
    difference of low/high-frequency power ratios (split at n_freq // 4,
    floored at 1). All NaNs when either input is empty.
    """
    if real_psd.size == 0 or gen_psd.size == 0:
        nan = float("nan")
        return {
            "avg_psd_l1": nan,
            "avg_psd_cosine": nan,
            "avg_low_high_ratio_abs_diff": nan,
        }
    l1_per_feature = np.abs(real_psd - gen_psd).mean(axis=1)
    cosine_dists: List[float] = []
    ratio_gaps: List[float] = []
    cutoff = max(1, real_psd.shape[1] // 4)
    for feature in range(real_psd.shape[0]):
        real_vec = real_psd[feature]
        gen_vec = gen_psd[feature]
        norm_product = max(np.linalg.norm(real_vec) * np.linalg.norm(gen_vec), 1e-8)
        cosine_dists.append(1.0 - float(np.dot(real_vec, gen_vec) / norm_product))
        real_ratio = float(real_vec[:cutoff].sum()) / max(float(real_vec[cutoff:].sum()), 1e-8)
        gen_ratio = float(gen_vec[:cutoff].sum()) / max(float(gen_vec[cutoff:].sum()), 1e-8)
        ratio_gaps.append(abs(real_ratio - gen_ratio))
    return {
        "avg_psd_l1": float(np.mean(l1_per_feature)),
        "avg_psd_cosine": float(np.mean(cosine_dists)),
        "avg_low_high_ratio_abs_diff": float(np.mean(ratio_gaps)),
    }
|
||||
|
||||
|
||||
def pairwise_sq_dists(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """All-pairs squared Euclidean distances, clipped at zero for float safety."""
    sq_x = (x ** 2).sum(axis=1)[:, None]
    sq_y = (y ** 2).sum(axis=1)[None, :]
    cross = x @ y.T
    return np.maximum(sq_x + sq_y - 2.0 * cross, 0.0)
|
||||
|
||||
|
||||
def median_heuristic_gamma(x: np.ndarray, y: np.ndarray) -> float:
    """RBF bandwidth via the median heuristic on up to 256 pooled samples.

    gamma = 1 / (2 * median of positive pairwise squared distances), with the
    denominator floored at 1e-8; degenerate inputs fall back to 1.0.
    """
    pooled = np.concatenate([x, y], axis=0)
    if pooled.shape[0] <= 1:
        return 1.0
    subset = pooled[: min(pooled.shape[0], 256)]
    d2 = pairwise_sq_dists(subset, subset)
    positive = d2[np.triu_indices_from(d2, k=1)]
    positive = positive[positive > 0]
    if positive.size == 0:
        return 1.0
    return 1.0 / max(2.0 * float(np.median(positive)), 1e-8)
|
||||
|
||||
|
||||
def rbf_mmd(x: np.ndarray, y: np.ndarray, gamma: Optional[float] = None) -> Tuple[float, float]:
    """RBF-kernel MMD^2 estimate between two samples.

    Within-sample kernel means exclude the diagonal (unbiased form); a sample
    of size one contributes 0. Returns (mmd, gamma_used); when either sample
    is empty the result is (NaN, 1.0). *gamma* defaults to the median
    heuristic.
    """
    if x.shape[0] == 0 or y.shape[0] == 0:
        return float("nan"), 1.0
    if gamma is None:
        gamma = median_heuristic_gamma(x, y)
    k_xx = np.exp(-gamma * pairwise_sq_dists(x, x))
    k_yy = np.exp(-gamma * pairwise_sq_dists(y, y))
    k_xy = np.exp(-gamma * pairwise_sq_dists(x, y))
    m = max(x.shape[0], 1)
    n = max(y.shape[0], 1)
    term_xx = (k_xx.sum() - np.trace(k_xx)) / (m * (m - 1)) if m > 1 else 0.0
    term_yy = (k_yy.sum() - np.trace(k_yy)) / (n * (n - 1)) if n > 1 else 0.0
    term_xy = 2.0 * k_xy.mean()
    return float(term_xx + term_yy - term_xy), float(gamma)
|
||||
|
||||
|
||||
def duplicate_rate(vectors: np.ndarray, decimals: int = 5) -> float:
    """Fraction of rows that duplicate another row after rounding to *decimals*."""
    n_rows = vectors.shape[0]
    if n_rows <= 1:
        return 0.0
    n_unique = np.unique(np.round(vectors, decimals=decimals), axis=0).shape[0]
    return float(1.0 - n_unique / n_rows)
|
||||
|
||||
|
||||
def exact_match_rate(query: np.ndarray, base: np.ndarray, decimals: int = 5) -> float:
    """Fraction of query rows that exactly match some base row after rounding."""
    if query.shape[0] == 0 or base.shape[0] == 0:
        return 0.0
    base_set = {tuple(row.tolist()) for row in np.round(base, decimals=decimals)}
    hits = sum(tuple(row.tolist()) in base_set for row in np.round(query, decimals=decimals))
    return float(hits / query.shape[0])
|
||||
|
||||
|
||||
def nearest_neighbor_distance_stats(query: np.ndarray, base: np.ndarray, batch_size: int = 128) -> Dict[str, float]:
    """Mean/median/min Euclidean distance from each query row to its nearest base row.

    Processed in batches of *batch_size* query rows to bound the size of the
    pairwise distance matrix. Returns all NaNs when either input is empty.
    """
    if query.shape[0] == 0 or base.shape[0] == 0:
        return {"mean": float("nan"), "median": float("nan"), "min": float("nan")}
    nearest: List[np.ndarray] = []
    for offset in range(0, query.shape[0], batch_size):
        block = query[offset:offset + batch_size]
        nearest.append(np.sqrt(pairwise_sq_dists(block, base).min(axis=1)))
    all_dists = np.concatenate(nearest, axis=0)
    return {
        "mean": float(all_dists.mean()),
        "median": float(np.median(all_dists)),
        "min": float(all_dists.min()),
    }
|
||||
|
||||
|
||||
def one_nn_two_sample_accuracy(real_vecs: np.ndarray, gen_vecs: np.ndarray) -> float:
    """Leave-one-out 1-NN accuracy on real-vs-generated membership labels.

    An accuracy near 0.5 suggests the two samples are indistinguishable.
    Returns NaN when either sample has fewer than two rows.
    """
    if real_vecs.shape[0] < 2 or gen_vecs.shape[0] < 2:
        return float("nan")
    pooled = np.concatenate([real_vecs, gen_vecs], axis=0)
    membership = np.concatenate(
        [
            np.zeros(real_vecs.shape[0], dtype=np.int64),
            np.ones(gen_vecs.shape[0], dtype=np.int64),
        ],
        axis=0,
    )
    distances = pairwise_sq_dists(pooled, pooled)
    np.fill_diagonal(distances, np.inf)  # exclude each point as its own neighbor
    neighbor = np.argmin(distances, axis=1)
    return float((membership[neighbor] == membership).mean())
|
||||
|
||||
|
||||
def binary_classification_curves(y_true: np.ndarray, scores: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Cumulative ROC/PR building blocks over descending-score order.

    Returns (sorted scores, FPR, TPR, precision). Positive/negative counts
    are floored at 1 to avoid division by zero.
    """
    order = np.argsort(-scores)
    labels_sorted = y_true[order].astype(np.int64)
    sorted_scores = scores[order]
    cum_tp = np.cumsum(labels_sorted == 1)
    cum_fp = np.cumsum(labels_sorted == 0)
    n_pos = max(int((y_true == 1).sum()), 1)
    n_neg = max(int((y_true == 0).sum()), 1)
    tpr = cum_tp / n_pos
    fpr = cum_fp / n_neg
    precision = cum_tp / np.maximum(cum_tp + cum_fp, 1)
    return sorted_scores, fpr, tpr, precision
|
||||
|
||||
|
||||
def binary_auroc(y_true: np.ndarray, scores: np.ndarray) -> float:
    """Area under the ROC curve; NaN when only one class is present.

    The curve is anchored at (0, 0) and (1, 1) and integrated with an
    explicit trapezoid sum rather than ``np.trapz``, which was removed in
    NumPy 2.0 (renamed ``np.trapezoid``).
    """
    if len(np.unique(y_true)) < 2:
        return float("nan")
    _, fpr, tpr, _ = binary_classification_curves(y_true, scores)
    fpr = np.concatenate([[0.0], fpr, [1.0]])
    tpr = np.concatenate([[0.0], tpr, [1.0]])
    # Trapezoidal integration of TPR over FPR.
    return float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) * 0.5))
|
||||
|
||||
|
||||
def binary_average_precision(y_true: np.ndarray, scores: np.ndarray) -> float:
    """Average precision (step-wise area under the PR curve).

    Returns NaN when only one class is present. The original implementation
    first called ``binary_classification_curves`` and then discarded and
    recomputed its precision output; that redundant call is removed.
    """
    if len(np.unique(y_true)) < 2:
        return float("nan")
    positives = max(int((y_true == 1).sum()), 1)
    order = np.argsort(-scores)
    y = y_true[order].astype(np.int64)
    tp = np.cumsum(y == 1)
    recall = tp / positives
    precision = tp / np.arange(1, len(tp) + 1)
    # Anchor the curve at recall 0 with the first precision value.
    recall = np.concatenate([[0.0], recall])
    precision = np.concatenate([[precision[0] if precision.size else 1.0], precision])
    return float(np.sum((recall[1:] - recall[:-1]) * precision[1:]))
|
||||
|
||||
|
||||
def binary_f1_at_threshold(y_true: np.ndarray, scores: np.ndarray, threshold: float) -> Dict[str, float]:
    """Precision/recall/F1 when predicting positive at scores >= threshold.

    Denominators are floored at 1, so empty prediction/positive sets yield 0.
    """
    predicted = (scores >= threshold).astype(np.int64)
    true_pos = int(np.sum((predicted == 1) & (y_true == 1)))
    false_pos = int(np.sum((predicted == 1) & (y_true == 0)))
    false_neg = int(np.sum((predicted == 0) & (y_true == 1)))
    precision = true_pos / max(true_pos + false_pos, 1)
    recall = true_pos / max(true_pos + false_neg, 1)
    f1 = 0.0 if precision + recall == 0 else 2.0 * precision * recall / (precision + recall)
    return {
        "threshold": float(threshold),
        "precision": float(precision),
        "recall": float(recall),
        "f1": float(f1),
    }
|
||||
|
||||
|
||||
def best_binary_f1(y_true: np.ndarray, scores: np.ndarray) -> Dict[str, float]:
    """Sweep every distinct score as a threshold and return the best-F1 stats.

    Returns all NaNs when *y_true* is empty.
    """
    if y_true.size == 0:
        return {"threshold": float("nan"), "precision": float("nan"), "recall": float("nan"), "f1": float("nan")}
    candidates = np.unique(scores)
    best = {"threshold": float(candidates[0]), "precision": 0.0, "recall": 0.0, "f1": -1.0}
    for candidate in candidates:
        current = binary_f1_at_threshold(y_true, scores, float(candidate))
        if current["f1"] > best["f1"]:
            best = current
    return best
|
||||
|
||||
|
||||
def set_random_seed(seed: int) -> None:
    """Seed both the stdlib and NumPy global RNGs for reproducible runs."""
    np.random.seed(seed)
    random.seed(seed)
|
||||
|
||||
Reference in New Issue
Block a user