Add comprehensive evaluation and ablation runner
This commit is contained in:
568
example/eval_utils.py
Normal file
568
example/eval_utils.py
Normal file
@@ -0,0 +1,568 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Shared utilities for evaluation, downstream utility, and ablations."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import gzip
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def load_json(path: str | Path) -> Dict:
    """Read and parse a UTF-8 encoded JSON file."""
    return json.loads(Path(path).read_text(encoding="utf-8"))
|
||||||
|
|
||||||
|
|
||||||
|
def open_csv(path: str | Path):
    """Open a CSV file for text reading; gzip files are decompressed transparently."""
    name = str(path)
    if not name.endswith(".gz"):
        return open(name, "r", newline="")
    return gzip.open(name, "rt", newline="")
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_path(base_dir: Path, path_like: str | Path) -> Path:
    """Resolve *path_like* against *base_dir*; absolute paths pass through unchanged."""
    candidate = Path(path_like)
    if candidate.is_absolute():
        return candidate
    return (base_dir / candidate).resolve()
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_reference_paths(ref_arg: str | Path) -> List[str]:
    """Resolve a reference argument (config JSON, glob, or file) to data paths.

    A ``.json`` config is read for its ``data_glob``/``data_path`` entry,
    interpreted relative to the config file's directory.

    Raises:
        SystemExit: when the config defines neither data key.
    """
    ref_path = Path(ref_arg)
    is_config = ref_path.suffix == ".json" and ref_path.exists()
    if not is_config:
        return expand_glob_or_file(ref_path)
    cfg = load_json(ref_path)
    data_glob = cfg.get("data_glob") or cfg.get("data_path") or ""
    if not data_glob:
        raise SystemExit("reference config has no data_glob/data_path")
    return expand_glob_or_file(ref_path.parent / data_glob)
|
||||||
|
|
||||||
|
|
||||||
|
def infer_test_paths(ref_arg: str | Path) -> List[str]:
    """Infer test CSV paths from a training reference (config JSON, glob, or file).

    Resolution order:
      1. If *ref_arg* is an existing ``.json`` config, honor an explicit
         ``test_glob``/``test_path``, otherwise recurse on its train data ref.
      2. Otherwise derive candidates by swapping ``train``/``TRAIN`` for
         ``test``/``TEST`` in the file name, then fall back to
         ``test*.csv[.gz]`` globs in the same directory.

    Raises:
        SystemExit: when no candidate matches an existing file, or when the
            config lacks a train data reference.
    """
    ref_path = Path(ref_arg)
    if ref_path.suffix == ".json" and ref_path.exists():
        cfg = load_json(ref_path)
        cfg_base = ref_path.parent
        explicit = cfg.get("test_glob") or cfg.get("test_path") or ""
        if explicit:
            return expand_glob_or_file(cfg_base / explicit)
        train_ref = cfg.get("data_glob") or cfg.get("data_path") or ""
        if not train_ref:
            raise SystemExit("reference config has no data_glob/data_path")
        return infer_test_paths(cfg_base / train_ref)

    path = Path(ref_arg)
    # BUG FIX: the original branched on glob characters in the path, but both
    # branches built an identical candidate list — the check was dead code and
    # has been collapsed into one path.
    parent = path.parent if path.parent != Path("") else Path(".")
    name = path.name
    candidates: List[Path] = []
    if "train" in name:
        # NOTE(review): the rendered source does not show where the original
        # `if "train" in name` block ended; the name substitutions are assumed
        # to be guarded and the glob fallbacks unconditional — confirm.
        candidates.append(parent / name.replace("train", "test"))
        candidates.append(parent / name.replace("TRAIN", "TEST"))
    candidates.append(parent / "test*.csv.gz")
    candidates.append(parent / "test*.csv")

    for candidate in candidates:
        matches = expand_glob_or_file(candidate)
        if matches:
            return matches
    raise SystemExit(f"could not infer test files from reference: {ref_arg}")
|
||||||
|
|
||||||
|
|
||||||
|
def expand_glob_or_file(path_like: str | Path) -> List[str]:
    """Expand a glob pattern or single file into sorted absolute path strings.

    Returns an empty list when the pattern matches nothing or the plain file
    does not exist.
    """
    path = Path(path_like)
    has_glob = any(marker in str(path) for marker in "*?[")
    if has_glob:
        base = path.parent if path.parent != Path("") else Path(".")
        return [str(match.resolve()) for match in sorted(base.glob(path.name))]
    if not path.exists():
        return []
    return [str(path.resolve())]
|
||||||
|
|
||||||
|
|
||||||
|
def load_split_columns(split_path: str | Path) -> Tuple[str, List[str], List[str], List[str]]:
    """Load the feature-split JSON and partition its columns.

    Returns ``(time_col, cont_cols, disc_cols, label_cols)`` where label
    columns are discrete columns whose lowercase name starts with "attack",
    and the time column is excluded from every group.
    """
    split = load_json(split_path)
    time_col = split.get("time_column", "time")
    cont_cols = [name for name in split["continuous"] if name != time_col]
    discrete = [name for name in split["discrete"] if name != time_col]
    label_cols = [name for name in discrete if name.lower().startswith("attack")]
    disc_cols = [name for name in discrete if name not in label_cols]
    return time_col, cont_cols, disc_cols, label_cols
|
||||||
|
|
||||||
|
|
||||||
|
def load_vocab(vocab_path: str | Path, disc_cols: Sequence[str]) -> Tuple[Dict[str, Dict[str, int]], List[int]]:
    """Load the discrete-feature vocabulary plus per-column vocabulary sizes."""
    vocab = load_json(vocab_path)["vocab"]
    sizes = [len(vocab[name]) for name in disc_cols]
    return vocab, sizes
|
||||||
|
|
||||||
|
|
||||||
|
def load_stats_vectors(stats_path: str | Path, cont_cols: Sequence[str]) -> Tuple[np.ndarray, np.ndarray]:
    """Load per-column mean/std vectors from the stats JSON.

    Prefers the ``raw_mean``/``raw_std`` entries when present, and floors the
    standard deviation at 1e-6 so later divisions are safe.
    """
    stats = load_json(stats_path)
    mean_map = stats.get("raw_mean", stats["mean"])
    std_map = stats.get("raw_std", stats["std"])
    means = np.asarray([float(mean_map[name]) for name in cont_cols], dtype=np.float32)
    stds = np.asarray([max(float(std_map[name]), 1e-6) for name in cont_cols], dtype=np.float32)
    return means, stds
|
||||||
|
|
||||||
|
|
||||||
|
def load_rows(
    path: str | Path,
    cont_cols: Sequence[str],
    disc_cols: Sequence[str],
    label_cols: Optional[Sequence[str]] = None,
    vocab: Optional[Dict[str, Dict[str, int]]] = None,
    max_rows: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
    """Read one (possibly gzipped) CSV into continuous, discrete, and label arrays.

    Args:
        path: CSV file; ``.gz`` is handled transparently via open_csv.
        cont_cols: column names parsed as float32 features.
        disc_cols: column names parsed as integer codes. With *vocab*, raw
            cell strings are mapped through ``vocab[col]`` falling back to the
            ``<UNK>`` entry (or 0); without it cells are cast ``int(float(...))``.
        label_cols: optional label columns; a row's label is the max over
            these columns, silently skipping unparseable cells.
        max_rows: stop after reading this many data rows, if given.

    Returns:
        ``(cont, disc, labels)`` — float32 (n, len(cont_cols)), int64
        (n, len(disc_cols)), and int64 (n,) or None when no label columns.
    """
    cont_rows: List[List[float]] = []
    disc_rows: List[List[int]] = []
    label_rows: List[int] = []
    label_cols = list(label_cols or [])

    with open_csv(path) as f:
        reader = csv.DictReader(f)
        for idx, row in enumerate(reader):
            cont_rows.append([float(row[c]) for c in cont_cols])
            if disc_cols:
                if vocab is None:
                    # No vocabulary: treat the cells as numeric codes directly.
                    disc_rows.append([int(float(row[c])) for c in disc_cols])
                else:
                    encoded = []
                    for c in disc_cols:
                        mapping = vocab[c]
                        # Unknown/missing tokens fall back to <UNK>, then to 0.
                        encoded.append(mapping.get(row.get(c, ""), mapping.get("<UNK>", 0)))
                    disc_rows.append(encoded)
            if label_cols:
                label = 0
                for c in label_cols:
                    try:
                        label = max(label, int(float(row.get(c, 0) or 0)))
                    except Exception:
                        # Best-effort: ignore label cells that do not parse.
                        continue
                label_rows.append(label)
            if max_rows is not None and idx + 1 >= max_rows:
                break

    cont = np.asarray(cont_rows, dtype=np.float32) if cont_rows else np.zeros((0, len(cont_cols)), dtype=np.float32)
    disc = np.asarray(disc_rows, dtype=np.int64) if disc_rows else np.zeros((0, len(disc_cols)), dtype=np.int64)
    labels = None
    if label_cols:
        labels = np.asarray(label_rows, dtype=np.int64)
    return cont, disc, labels
|
||||||
|
|
||||||
|
|
||||||
|
def window_array(
    cont: np.ndarray,
    disc: np.ndarray,
    labels: Optional[np.ndarray],
    seq_len: int,
    stride: Optional[int] = None,
    max_windows: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
    """Slice row-aligned arrays into fixed-length windows.

    Args:
        cont: (n_rows, n_cont) continuous features.
        disc: (n_rows, n_disc) discrete codes, row-aligned with *cont*.
        labels: optional per-row labels; each window's label is the max label
            inside it (any positive row marks the whole window).
        seq_len: window length in rows.
        stride: step between window starts; None or <= 0 means
            non-overlapping windows (stride = seq_len).
        max_windows: optional cap on the number of windows produced.

    Returns:
        ``(cont_windows, disc_windows, window_labels)``; empty arrays when
        there are fewer rows than *seq_len* (labels stay None when not given).
    """
    if stride is None or stride <= 0:
        stride = seq_len
    cont_windows: List[np.ndarray] = []
    disc_windows: List[np.ndarray] = []
    label_windows: List[int] = []
    if cont.shape[0] < seq_len:
        # Not enough rows for even one window: return correctly-shaped empties.
        return (
            np.zeros((0, seq_len, cont.shape[1]), dtype=np.float32),
            np.zeros((0, seq_len, disc.shape[1]), dtype=np.int64),
            np.zeros((0,), dtype=np.int64) if labels is not None else None,
        )

    count = 0
    for start in range(0, cont.shape[0] - seq_len + 1, stride):
        end = start + seq_len
        cont_windows.append(cont[start:end])
        disc_windows.append(disc[start:end])
        if labels is not None:
            label_windows.append(int(labels[start:end].max()))
        count += 1
        if max_windows is not None and count >= max_windows:
            break

    cont_out = np.asarray(cont_windows, dtype=np.float32)
    disc_out = np.asarray(disc_windows, dtype=np.int64)
    label_out = np.asarray(label_windows, dtype=np.int64) if labels is not None else None
    return cont_out, disc_out, label_out
|
||||||
|
|
||||||
|
|
||||||
|
def load_windows_from_paths(
    paths: Sequence[str],
    cont_cols: Sequence[str],
    disc_cols: Sequence[str],
    seq_len: int,
    vocab: Optional[Dict[str, Dict[str, int]]] = None,
    label_cols: Optional[Sequence[str]] = None,
    stride: Optional[int] = None,
    max_windows: Optional[int] = None,
    max_rows_per_file: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
    """Load several CSV files and concatenate their windows.

    Windows never span file boundaries. *max_windows* caps the overall
    total across files; *max_rows_per_file* caps rows read from each file.

    Returns:
        ``(cont_windows, disc_windows, labels)``; labels is None unless
        *label_cols* is given, and correctly-shaped empty arrays are returned
        when no file yields a full window.
    """
    cont_all: List[np.ndarray] = []
    disc_all: List[np.ndarray] = []
    label_all: List[np.ndarray] = []
    total = 0

    for path in paths:
        # Remaining budget under the global window cap; 0 means we are done.
        remaining = None if max_windows is None else max(0, max_windows - total)
        if remaining == 0:
            break
        cont, disc, labels = load_rows(
            path,
            cont_cols,
            disc_cols,
            label_cols=label_cols,
            vocab=vocab,
            max_rows=max_rows_per_file,
        )
        w_cont, w_disc, w_labels = window_array(
            cont,
            disc,
            labels,
            seq_len=seq_len,
            stride=stride,
            max_windows=remaining,
        )
        if w_cont.size == 0:
            # File too short for a single window — skip it.
            continue
        cont_all.append(w_cont)
        disc_all.append(w_disc)
        if w_labels is not None:
            label_all.append(w_labels)
        total += w_cont.shape[0]

    if not cont_all:
        empty_cont = np.zeros((0, seq_len, len(cont_cols)), dtype=np.float32)
        empty_disc = np.zeros((0, seq_len, len(disc_cols)), dtype=np.int64)
        empty_labels = np.zeros((0,), dtype=np.int64) if label_cols else None
        return empty_cont, empty_disc, empty_labels

    cont_out = np.concatenate(cont_all, axis=0)
    disc_out = np.concatenate(disc_all, axis=0)
    label_out = np.concatenate(label_all, axis=0) if label_all else None
    return cont_out, disc_out, label_out
|
||||||
|
|
||||||
|
|
||||||
|
def filter_windows_by_label(
    cont_windows: np.ndarray,
    disc_windows: np.ndarray,
    labels: Optional[np.ndarray],
    target_label: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Keep only the windows whose label equals *target_label*.

    Raises:
        ValueError: when *labels* is None.
    """
    if labels is None:
        raise ValueError("labels are required for label filtering")
    keep = labels == int(target_label)
    return cont_windows[keep], disc_windows[keep], labels[keep]
|
||||||
|
|
||||||
|
|
||||||
|
def standardize_cont_windows(
    cont_windows: np.ndarray,
    mean_vec: np.ndarray,
    std_vec: np.ndarray,
) -> np.ndarray:
    """Z-score continuous windows per feature via broadcast mean/std vectors."""
    mean = mean_vec.reshape(1, 1, -1)
    std = std_vec.reshape(1, 1, -1)
    return (cont_windows - mean) / std
|
||||||
|
|
||||||
|
|
||||||
|
def build_flat_window_vectors(
    cont_windows: np.ndarray,
    disc_windows: np.ndarray,
    mean_vec: np.ndarray,
    std_vec: np.ndarray,
    vocab_sizes: Sequence[int],
) -> np.ndarray:
    """Flatten each window into one feature vector.

    Continuous features are z-scored; discrete codes are scaled into [0, 1]
    by (vocab_size - 1); both are flattened and concatenated per window.
    """
    n_windows = cont_windows.shape[0]
    cont_flat = standardize_cont_windows(cont_windows, mean_vec, std_vec).reshape(n_windows, -1)
    if disc_windows.size == 0:
        return cont_flat.astype(np.float32)
    scale = np.asarray([max(size - 1, 1) for size in vocab_sizes], dtype=np.float32)
    disc_flat = (disc_windows.astype(np.float32) / scale.reshape(1, 1, -1)).reshape(disc_windows.shape[0], -1)
    return np.concatenate([cont_flat, disc_flat], axis=1).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def build_histogram_embeddings(
    cont_windows: np.ndarray,
    disc_windows: np.ndarray,
    mean_vec: np.ndarray,
    std_vec: np.ndarray,
    vocab_sizes: Sequence[int],
) -> np.ndarray:
    """Per-window embeddings: flattened z-scored continuous features plus, for
    each discrete column, the within-window frequency of every vocab value."""
    n_windows = cont_windows.shape[0]
    cont_flat = standardize_cont_windows(cont_windows, mean_vec, std_vec).reshape(n_windows, -1)
    if disc_windows.size == 0:
        return cont_flat.astype(np.float32)
    histograms: List[np.ndarray] = []
    for column, size in enumerate(vocab_sizes):
        codes = disc_windows[:, :, column]
        hist = np.zeros((disc_windows.shape[0], size), dtype=np.float32)
        for value in range(size):
            # Fraction of timesteps in each window equal to this code.
            hist[:, value] = (codes == value).mean(axis=1)
        histograms.append(hist)
    if histograms:
        disc_part = np.concatenate(histograms, axis=1)
    else:
        disc_part = np.zeros((cont_windows.shape[0], 0), dtype=np.float32)
    return np.concatenate([cont_flat, disc_part], axis=1).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def sample_indices(n_items: int, max_items: Optional[int], seed: int) -> np.ndarray:
    """Return sorted indices: all of them, or a seeded random subset of size *max_items*."""
    if max_items is None or n_items <= max_items:
        return np.arange(n_items, dtype=np.int64)
    generator = np.random.default_rng(seed)
    chosen = generator.choice(n_items, size=max_items, replace=False)
    return np.sort(chosen)
|
||||||
|
|
||||||
|
|
||||||
|
def subset_by_indices(array: Optional[np.ndarray], indices: np.ndarray) -> Optional[np.ndarray]:
    """Index *array* with *indices*, passing ``None`` through unchanged.

    The original annotated *array* as ``np.ndarray`` while explicitly
    handling ``None`` — the annotation is corrected to ``Optional``.
    """
    if array is None:
        return None
    return array[indices]
|
||||||
|
|
||||||
|
|
||||||
|
def compute_corr_matrix(rows: np.ndarray) -> np.ndarray:
    """Feature-by-feature correlation matrix with NaN/inf replaced by zero.

    Returns an all-zero matrix when fewer than two rows are available.
    """
    n_features = rows.shape[1]
    if rows.shape[0] < 2:
        return np.zeros((n_features, n_features), dtype=np.float32)
    corr = np.corrcoef(rows, rowvar=False)
    return np.nan_to_num(corr, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_lagged_corr_matrix(rows: np.ndarray, lag: int = 1) -> np.ndarray:
    """Cross-correlation between each feature at time t and each at t + lag.

    Entry (i, j) is corr(rows[t, i], rows[t + lag, j]). Returns an all-zero
    matrix when there are not enough rows, and maps NaN/inf to zero.
    """
    if rows.shape[0] <= lag:
        return np.zeros((rows.shape[1], rows.shape[1]), dtype=np.float32)
    x = rows[:-lag]
    y = rows[lag:]
    x = x - x.mean(axis=0, keepdims=True)
    y = y - y.mean(axis=0, keepdims=True)
    # Sample cross-covariance (denominator floored at 1 for tiny inputs).
    cov = x.T @ y / max(x.shape[0] - 1, 1)
    std_x = np.sqrt(np.maximum((x ** 2).sum(axis=0) / max(x.shape[0] - 1, 1), 1e-8))
    std_y = np.sqrt(np.maximum((y ** 2).sum(axis=0) / max(y.shape[0] - 1, 1), 1e-8))
    denom = np.outer(std_x, std_y)
    # Epsilon guard keeps constant features from dividing by ~zero.
    corr = cov / np.maximum(denom, 1e-8)
    return np.nan_to_num(corr, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def split_process_groups(feature_names: Sequence[str]) -> Dict[str, List[int]]:
    """Group feature indices by the prefix before the first underscore."""
    groups: Dict[str, List[int]] = {}
    for index, feature in enumerate(feature_names):
        key = feature.split("_", 1)[0]
        if key not in groups:
            groups[key] = []
        groups[key].append(index)
    return groups
|
||||||
|
|
||||||
|
|
||||||
|
def compute_average_psd(cont_windows: np.ndarray) -> np.ndarray:
    """Average normalized power spectral density per feature.

    Returns shape (n_features, n_freqs); each row sums to 1 unless the
    feature has zero total power, in which case it stays all-zero.
    """
    if cont_windows.size == 0:
        return np.zeros((0, 0), dtype=np.float32)
    demeaned = cont_windows - cont_windows.mean(axis=1, keepdims=True)
    spectrum = np.fft.rfft(demeaned, axis=1)
    power = (np.abs(spectrum) ** 2).mean(axis=0).T.astype(np.float32)
    totals = power.sum(axis=1, keepdims=True)
    totals[totals <= 0] = 1.0  # avoid dividing zero-power rows by zero
    return power / totals
|
||||||
|
|
||||||
|
|
||||||
|
def psd_distance_stats(real_psd: np.ndarray, gen_psd: np.ndarray) -> Dict[str, float]:
    """Compare real vs generated average PSDs feature by feature.

    Reports the mean L1 distance, mean cosine distance, and the mean
    absolute difference of low/high-frequency power ratios, with the
    low band being the first quarter of frequency bins (at least one bin).
    Returns NaNs when either input is empty.
    """
    if real_psd.size == 0 or gen_psd.size == 0:
        nan = float("nan")
        return {
            "avg_psd_l1": nan,
            "avg_psd_cosine": nan,
            "avg_low_high_ratio_abs_diff": nan,
        }
    l1_per_feature = np.abs(real_psd - gen_psd).mean(axis=1)
    split = max(1, real_psd.shape[1] // 4)
    cosine_dists: List[float] = []
    ratio_gaps: List[float] = []
    for feature_idx in range(real_psd.shape[0]):
        rv = real_psd[feature_idx]
        gv = gen_psd[feature_idx]
        norm_product = max(np.linalg.norm(rv) * np.linalg.norm(gv), 1e-8)
        cosine_dists.append(1.0 - float(np.dot(rv, gv) / norm_product))
        real_ratio = float(rv[:split].sum()) / max(float(rv[split:].sum()), 1e-8)
        gen_ratio = float(gv[:split].sum()) / max(float(gv[split:].sum()), 1e-8)
        ratio_gaps.append(abs(real_ratio - gen_ratio))
    return {
        "avg_psd_l1": float(np.mean(l1_per_feature)),
        "avg_psd_cosine": float(np.mean(cosine_dists)),
        "avg_low_high_ratio_abs_diff": float(np.mean(ratio_gaps)),
    }
|
||||||
|
|
||||||
|
|
||||||
|
def pairwise_sq_dists(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Squared Euclidean distances between every row of *x* and every row of *y*.

    Uses the expansion ||a-b||^2 = ||a||^2 + ||b||^2 - 2 a.b, clamped at
    zero to absorb negative values from floating-point round-off.
    """
    sq_x = np.sum(x * x, axis=1)[:, None]
    sq_y = np.sum(y * y, axis=1)[None, :]
    return np.maximum(sq_x + sq_y - 2.0 * x.dot(y.T), 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
def median_heuristic_gamma(x: np.ndarray, y: np.ndarray) -> float:
    """RBF bandwidth via the median heuristic on up to 256 pooled samples.

    gamma = 1 / (2 * median nonzero squared distance); falls back to 1.0
    when there are not enough distinct points.
    """
    pooled = np.concatenate([x, y], axis=0)
    if pooled.shape[0] <= 1:
        return 1.0
    sample = pooled[: min(pooled.shape[0], 256)]
    d2 = pairwise_sq_dists(sample, sample)
    upper_tri = d2[np.triu_indices_from(d2, k=1)]
    upper_tri = upper_tri[upper_tri > 0]
    if upper_tri.size == 0:
        return 1.0
    return 1.0 / max(2.0 * float(np.median(upper_tri)), 1e-8)
|
||||||
|
|
||||||
|
|
||||||
|
def rbf_mmd(x: np.ndarray, y: np.ndarray, gamma: Optional[float] = None) -> Tuple[float, float]:
    """Squared maximum mean discrepancy between samples x and y (RBF kernel).

    Args:
        gamma: kernel bandwidth; chosen by the median heuristic when None.

    Returns:
        ``(mmd2, gamma_used)``; ``(NaN, 1.0)`` when either sample is empty.
        Within-sample terms exclude the kernel diagonal (unbiased form),
        degrading to 0.0 for singleton samples.
    """
    if x.shape[0] == 0 or y.shape[0] == 0:
        return float("nan"), 1.0
    if gamma is None:
        gamma = median_heuristic_gamma(x, y)
    k_xx = np.exp(-gamma * pairwise_sq_dists(x, x))
    k_yy = np.exp(-gamma * pairwise_sq_dists(y, y))
    k_xy = np.exp(-gamma * pairwise_sq_dists(x, y))
    m = max(x.shape[0], 1)
    n = max(y.shape[0], 1)
    if m > 1:
        # Unbiased estimate: drop self-similarity on the diagonal.
        term_xx = (k_xx.sum() - np.trace(k_xx)) / (m * (m - 1))
    else:
        term_xx = 0.0
    if n > 1:
        term_yy = (k_yy.sum() - np.trace(k_yy)) / (n * (n - 1))
    else:
        term_yy = 0.0
    term_xy = 2.0 * k_xy.mean()
    return float(term_xx + term_yy - term_xy), float(gamma)
|
||||||
|
|
||||||
|
|
||||||
|
def duplicate_rate(vectors: np.ndarray, decimals: int = 5) -> float:
    """Fraction of rows that duplicate another row after rounding to *decimals*."""
    total = vectors.shape[0]
    if total <= 1:
        return 0.0
    n_unique = np.unique(np.round(vectors, decimals=decimals), axis=0).shape[0]
    return float(1.0 - n_unique / total)
|
||||||
|
|
||||||
|
|
||||||
|
def exact_match_rate(query: np.ndarray, base: np.ndarray, decimals: int = 5) -> float:
    """Fraction of query rows that exactly match some base row after rounding."""
    if query.shape[0] == 0 or base.shape[0] == 0:
        return 0.0
    base_rows = {tuple(row.tolist()) for row in np.round(base, decimals=decimals)}
    hits = sum(tuple(row.tolist()) in base_rows for row in np.round(query, decimals=decimals))
    return float(hits / query.shape[0])
|
||||||
|
|
||||||
|
|
||||||
|
def nearest_neighbor_distance_stats(query: np.ndarray, base: np.ndarray, batch_size: int = 128) -> Dict[str, float]:
    """Mean/median/min Euclidean distance from each query row to its nearest base row.

    Query rows are processed in batches of *batch_size* to bound the size of
    the intermediate distance matrix. Returns NaNs when either side is empty.
    """
    if query.shape[0] == 0 or base.shape[0] == 0:
        nan = float("nan")
        return {"mean": nan, "median": nan, "min": nan}
    per_batch: List[np.ndarray] = []
    for offset in range(0, query.shape[0], batch_size):
        block = query[offset : offset + batch_size]
        per_batch.append(np.sqrt(pairwise_sq_dists(block, base).min(axis=1)))
    nearest = np.concatenate(per_batch, axis=0)
    return {
        "mean": float(nearest.mean()),
        "median": float(np.median(nearest)),
        "min": float(nearest.min()),
    }
|
||||||
|
|
||||||
|
|
||||||
|
def one_nn_two_sample_accuracy(real_vecs: np.ndarray, gen_vecs: np.ndarray) -> float:
    """Leave-one-out 1-NN accuracy at telling real vectors from generated ones.

    An accuracy near 0.5 means the two samples are hard to distinguish;
    values near 1.0 mean they separate easily. Requires at least two
    vectors per side, otherwise NaN.
    """
    if real_vecs.shape[0] < 2 or gen_vecs.shape[0] < 2:
        return float("nan")
    points = np.concatenate([real_vecs, gen_vecs], axis=0)
    truth = np.concatenate(
        [
            np.zeros(real_vecs.shape[0], dtype=np.int64),
            np.ones(gen_vecs.shape[0], dtype=np.int64),
        ],
        axis=0,
    )
    d2 = pairwise_sq_dists(points, points)
    np.fill_diagonal(d2, np.inf)  # a point may not pick itself as neighbor
    neighbor = np.argmin(d2, axis=1)
    return float((truth[neighbor] == truth).mean())
|
||||||
|
|
||||||
|
|
||||||
|
def binary_classification_curves(y_true: np.ndarray, scores: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Sweep thresholds from high to low score, accumulating curve points.

    Returns ``(sorted_scores, fpr, tpr, precision)``, one entry per example
    in descending-score order. Positive/negative counts are floored at 1 so
    single-class inputs do not divide by zero.

    Fix: removed the unused local ``recall = tpr`` from the original.
    """
    order = np.argsort(-scores)
    y = y_true[order].astype(np.int64)
    scores_sorted = scores[order]
    tp = np.cumsum(y == 1)
    fp = np.cumsum(y == 0)
    positives = max(int((y_true == 1).sum()), 1)
    negatives = max(int((y_true == 0).sum()), 1)
    tpr = tp / positives
    fpr = fp / negatives
    precision = tp / np.maximum(tp + fp, 1)
    return scores_sorted, fpr, tpr, precision
|
||||||
|
|
||||||
|
|
||||||
|
def binary_auroc(y_true: np.ndarray, scores: np.ndarray) -> float:
    """Area under the ROC curve via trapezoidal integration.

    Returns NaN when only one class is present. The curve is anchored at
    (0, 0) and (1, 1) before integrating.
    """
    if len(np.unique(y_true)) < 2:
        return float("nan")
    _, fpr, tpr, _ = binary_classification_curves(y_true, scores)
    fpr = np.concatenate([[0.0], fpr, [1.0]])
    tpr = np.concatenate([[0.0], tpr, [1.0]])
    # NOTE(review): np.trapz is deprecated in NumPy 2.x in favor of
    # np.trapezoid — confirm the project's minimum NumPy version.
    return float(np.trapz(tpr, fpr))
|
||||||
|
|
||||||
|
|
||||||
|
def binary_average_precision(y_true: np.ndarray, scores: np.ndarray) -> float:
    """Average precision (step-wise area under the precision-recall curve).

    Returns NaN when only one class is present.

    Fix: the original first called binary_classification_curves() and then
    immediately recomputed (overwriting) precision from scratch; the dead
    call has been removed — the result is unchanged.
    """
    if len(np.unique(y_true)) < 2:
        return float("nan")
    positives = max(int((y_true == 1).sum()), 1)
    order = np.argsort(-scores)
    y = y_true[order].astype(np.int64)
    tp = np.cumsum(y == 1)
    recall = tp / positives
    # Precision at rank k = TP@k / k.
    precision = tp / np.arange(1, len(tp) + 1)
    recall = np.concatenate([[0.0], recall])
    precision = np.concatenate([[precision[0] if precision.size else 1.0], precision])
    return float(np.sum((recall[1:] - recall[:-1]) * precision[1:]))
|
||||||
|
|
||||||
|
|
||||||
|
def binary_f1_at_threshold(y_true: np.ndarray, scores: np.ndarray, threshold: float) -> Dict[str, float]:
    """Precision/recall/F1 when predicting positive for scores >= threshold.

    Denominators are floored at 1, so degenerate inputs yield 0 rather
    than dividing by zero.
    """
    predicted = (scores >= threshold).astype(np.int64)
    true_pos = int(((predicted == 1) & (y_true == 1)).sum())
    false_pos = int(((predicted == 1) & (y_true == 0)).sum())
    false_neg = int(((predicted == 0) & (y_true == 1)).sum())
    precision = true_pos / max(true_pos + false_pos, 1)
    recall = true_pos / max(true_pos + false_neg, 1)
    f1 = 0.0 if precision + recall == 0 else 2.0 * precision * recall / (precision + recall)
    return {
        "threshold": float(threshold),
        "precision": float(precision),
        "recall": float(recall),
        "f1": float(f1),
    }
|
||||||
|
|
||||||
|
|
||||||
|
def best_binary_f1(y_true: np.ndarray, scores: np.ndarray) -> Dict[str, float]:
    """Try every distinct score as a threshold and keep the best-F1 stats."""
    if y_true.size == 0:
        nan = float("nan")
        return {"threshold": nan, "precision": nan, "recall": nan, "f1": nan}
    candidates = np.unique(scores)
    best = {"threshold": float(candidates[0]), "precision": 0.0, "recall": 0.0, "f1": -1.0}
    for candidate in candidates:
        current = binary_f1_at_threshold(y_true, scores, float(candidate))
        if current["f1"] > best["f1"]:
            best = current
    return best
|
||||||
|
|
||||||
|
|
||||||
|
def set_random_seed(seed: int) -> None:
    """Seed both the stdlib RNG and the legacy global NumPy RNG for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
|
||||||
|
|
||||||
772
example/evaluate_comprehensive.py
Normal file
772
example/evaluate_comprehensive.py
Normal file
@@ -0,0 +1,772 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Comprehensive evaluation suite for generated ICS feature sequences."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.utils.data import DataLoader, TensorDataset
|
||||||
|
|
||||||
|
from eval_utils import (
|
||||||
|
best_binary_f1,
|
||||||
|
binary_auroc,
|
||||||
|
binary_average_precision,
|
||||||
|
binary_f1_at_threshold,
|
||||||
|
build_flat_window_vectors,
|
||||||
|
build_histogram_embeddings,
|
||||||
|
compute_average_psd,
|
||||||
|
compute_corr_matrix,
|
||||||
|
duplicate_rate,
|
||||||
|
exact_match_rate,
|
||||||
|
infer_test_paths,
|
||||||
|
load_json,
|
||||||
|
load_split_columns,
|
||||||
|
load_stats_vectors,
|
||||||
|
load_vocab,
|
||||||
|
load_windows_from_paths,
|
||||||
|
nearest_neighbor_distance_stats,
|
||||||
|
one_nn_two_sample_accuracy,
|
||||||
|
psd_distance_stats,
|
||||||
|
rbf_mmd,
|
||||||
|
resolve_reference_paths,
|
||||||
|
sample_indices,
|
||||||
|
set_random_seed,
|
||||||
|
split_process_groups,
|
||||||
|
standardize_cont_windows,
|
||||||
|
)
|
||||||
|
from platform_utils import resolve_device
|
||||||
|
from window_models import MLPAutoencoder, MLPClassifier, MLPRegressor
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
    """CLI arguments for the comprehensive evaluation script.

    Path defaults resolve relative to this file's directory. A value of 0
    for --seq-len / --stride / --max-rows-per-file acts as an "unset"
    sentinel (for --stride, 0 means non-overlapping windows).
    """
    base_dir = Path(__file__).resolve().parent
    parser = argparse.ArgumentParser(description="Comprehensive evaluation for generated.csv.")
    # Input/output artifacts produced by the training and generation pipeline.
    parser.add_argument("--generated", default=str(base_dir / "results" / "generated.csv"))
    parser.add_argument("--reference", default=str(base_dir / "config.json"))
    parser.add_argument("--config", default=str(base_dir / "config.json"))
    parser.add_argument("--split", default=str(base_dir / "feature_split.json"))
    parser.add_argument("--stats", default=str(base_dir / "results" / "cont_stats.json"))
    parser.add_argument("--vocab", default=str(base_dir / "results" / "disc_vocab.json"))
    parser.add_argument("--out", default=str(base_dir / "results" / "comprehensive_eval.json"))
    # Windowing controls.
    parser.add_argument("--seq-len", type=int, default=0)
    parser.add_argument("--stride", type=int, default=0, help="0 means non-overlapping windows")
    parser.add_argument("--max-train-windows", type=int, default=1024)
    parser.add_argument("--max-generated-windows", type=int, default=1024)
    parser.add_argument("--max-test-windows", type=int, default=1024)
    parser.add_argument("--max-rows-per-file", type=int, default=0)
    # Runtime and training hyper-parameters for the downstream probes.
    parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
    parser.add_argument("--seed", type=int, default=1337)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--classifier-epochs", type=int, default=12)
    parser.add_argument("--predictor-epochs", type=int, default=12)
    parser.add_argument("--detector-epochs", type=int, default=16)
    parser.add_argument("--hidden-dim", type=int, default=256)
    parser.add_argument("--detector-threshold-quantile", type=float, default=0.995)
    return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def flatten_rows(windows: np.ndarray) -> np.ndarray:
    """Collapse (n_windows, seq_len, n_features) into (n_rows, n_features)."""
    if windows.size == 0:
        return np.zeros((0, 0), dtype=np.float32)
    n_features = windows.shape[-1]
    return windows.reshape(-1, n_features)
|
||||||
|
|
||||||
|
|
||||||
|
def lagged_corr_from_windows(windows: np.ndarray, lag: int = 1) -> np.ndarray:
    """Cross-correlation between features at time t and t + lag, pooled over windows.

    Each window contributes its (t, t + lag) pairs; the pairs are flattened
    together before computing the correlation. Returns an all-zero matrix
    when there are no windows or windows are too short, and maps NaN/inf to 0.
    """
    if windows.shape[0] == 0 or windows.shape[1] <= lag:
        return np.zeros((windows.shape[-1], windows.shape[-1]), dtype=np.float32)
    x = windows[:, :-lag, :].reshape(-1, windows.shape[-1])
    y = windows[:, lag:, :].reshape(-1, windows.shape[-1])
    x = x - x.mean(axis=0, keepdims=True)
    y = y - y.mean(axis=0, keepdims=True)
    # Sample cross-covariance (denominator floored at 1 for tiny inputs).
    cov = x.T @ y / max(x.shape[0] - 1, 1)
    std_x = np.sqrt(np.maximum((x ** 2).sum(axis=0) / max(x.shape[0] - 1, 1), 1e-8))
    std_y = np.sqrt(np.maximum((y ** 2).sum(axis=0) / max(y.shape[0] - 1, 1), 1e-8))
    denom = np.outer(std_x, std_y)
    # Epsilon guard keeps constant features from dividing by ~zero.
    corr = cov / np.maximum(denom, 1e-8)
    return np.nan_to_num(corr, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def mean_abs_matrix_diff(a: np.ndarray, b: np.ndarray) -> float:
    """Mean absolute elementwise difference; NaN when either matrix is empty."""
    if not (a.size and b.size):
        return float("nan")
    return float(np.abs(a - b).mean())
|
||||||
|
|
||||||
|
|
||||||
|
def fro_matrix_diff(a: np.ndarray, b: np.ndarray) -> float:
    """Frobenius norm of (a - b); NaN when either matrix is empty."""
    if not (a.size and b.size):
        return float("nan")
    return float(np.linalg.norm(a - b))
|
||||||
|
|
||||||
|
|
||||||
|
def safe_mean(values: Iterable[float]) -> float:
    """Mean over the non-None, non-NaN entries; NaN when nothing survives."""
    kept = [float(v) for v in values if v is not None and not math.isnan(float(v))]
    if not kept:
        return float("nan")
    return sum(kept) / len(kept)
|
||||||
|
|
||||||
|
|
||||||
|
def safe_median(values: Sequence[float]) -> float:
    """Median of *values*; NaN for an empty sequence."""
    if not values:
        return float("nan")
    ordered = sorted(float(v) for v in values)
    middle = len(ordered) // 2
    if len(ordered) % 2:
        return float(ordered[middle])
    return 0.5 * (ordered[middle - 1] + ordered[middle])
|
||||||
|
|
||||||
|
|
||||||
|
def dwell_and_steps(series: Sequence[float]) -> Dict[str, float]:
    """Run-length statistics of a (quantized) signal.

    Counts value changes, the dwell time (length) of every constant run
    including the final one, and the absolute step size at each change.
    Returns all-NaN for an empty series; step stats are NaN when the signal
    never changes.
    """
    if not series:
        return {
            "num_changes": float("nan"),
            "mean_dwell": float("nan"),
            "median_dwell": float("nan"),
            "mean_step": float("nan"),
            "median_step": float("nan"),
        }
    changes = 0
    dwells: List[float] = []
    steps: List[float] = []
    current = float(series[0])
    dwell = 1
    for value in series[1:]:
        value = float(value)
        if value == current:
            dwell += 1
            continue
        changes += 1
        dwells.append(float(dwell))
        steps.append(abs(value - current))
        current = value
        dwell = 1
    dwells.append(float(dwell))  # close out the final run
    return {
        "num_changes": float(changes),
        "mean_dwell": safe_mean(dwells),
        "median_dwell": safe_median(dwells),
        "mean_step": safe_mean(steps),
        "median_step": safe_median(steps),
    }
|
||||||
|
|
||||||
|
|
||||||
|
def controller_stats(series: Sequence[float], vmin: float, vmax: float) -> Dict[str, float]:
    """Saturation and activity summary for a bounded controller signal.

    A sample counts as saturated when it lies within 1% of the [vmin, vmax]
    range of either bound. All fields are NaN for an empty series.
    """
    if not series:
        return {"saturation_ratio": float("nan"), "change_rate": float("nan"), "step_median": float("nan")}
    span = vmax - vmin
    tol = 0.01 * span if span > 0 else 0.0
    low_edge = vmin + tol
    high_edge = vmax - tol
    saturated = sum(1 for v in series if v <= low_edge or v >= high_edge)
    deltas: List[float] = []
    last = float(series[0])
    for raw in series[1:]:
        cur = float(raw)
        if cur != last:
            deltas.append(abs(cur - last))
        last = cur
    return {
        "saturation_ratio": float(saturated / len(series)),
        "change_rate": float(len(deltas) / max(len(series) - 1, 1)),
        "step_median": safe_median(deltas),
    }
def actuator_stats(series: Sequence[float]) -> Dict[str, float]:
    """Discreteness summary of an actuator signal.

    Values are rounded to 2 decimals before counting so near-identical
    readings collapse into one level. All fields are NaN for an empty series.
    """
    if not series:
        nan = float("nan")
        return {
            "unique_ratio": nan,
            "top1_mass": nan,
            "top3_mass": nan,
            "median_dwell": nan,
        }
    quantized = [round(float(v), 2) for v in series]
    n = len(quantized)
    counts: Dict[float, int] = {}
    for level in quantized:
        counts[level] = counts.get(level, 0) + 1
    ranked = sorted(counts.values(), reverse=True)
    # Run-length encode the quantized signal to get dwell times per level.
    dwells: List[float] = []
    run_val = quantized[0]
    run_len = 1
    for level in quantized[1:]:
        if level == run_val:
            run_len += 1
        else:
            dwells.append(float(run_len))
            run_val = level
            run_len = 1
    dwells.append(float(run_len))
    return {
        "unique_ratio": float(len(counts) / n),
        "top1_mass": float(ranked[0] / n) if ranked else float("nan"),
        "top3_mass": float(sum(ranked[:3]) / n) if ranked else float("nan"),
        "median_dwell": safe_median(dwells),
    }
def pv_stats(series: Sequence[float]) -> Dict[str, float]:
    """Empirical quantiles and tail asymmetry of a process-variable signal.

    tail_ratio compares the upper spread (q95-q50) to the lower spread
    (q50-q05); NaN when the lower spread is zero or the series is empty.
    """
    if not series:
        nan = float("nan")
        return {"q05": nan, "q50": nan, "q95": nan, "tail_ratio": nan}
    ordered = sorted(float(v) for v in series)
    last = len(ordered) - 1

    def quantile(prob: float) -> float:
        # Nearest-rank quantile on the sorted values, clamped to valid indices.
        pos = int(round(prob * last))
        return float(ordered[min(max(pos, 0), last)])

    q05 = quantile(0.05)
    q50 = quantile(0.5)
    q95 = quantile(0.95)
    lower = q50 - q05
    upper = q95 - q50
    ratio = upper / lower if lower != 0 else float("nan")
    return {"q05": q05, "q50": q50, "q95": q95, "tail_ratio": float(ratio)}
def aux_stats(series: Sequence[float]) -> Dict[str, float]:
    """Mean, sample std, and lag-1 autocorrelation of an auxiliary signal.

    Computed in float32. std uses ddof=1 (0.0 for a single sample); lag1 is
    NaN when fewer than two samples are available.
    """
    if not series:
        return {"mean": float("nan"), "std": float("nan"), "lag1": float("nan")}
    data = np.asarray(series, dtype=np.float32)
    avg = float(data.mean())
    spread = float(data.std(ddof=1)) if data.size > 1 else 0.0
    if data.size < 2:
        return {"mean": avg, "std": spread, "lag1": float("nan")}
    # Lag-1 autocorrelation as cosine similarity of the two centered shifts.
    head = data[:-1] - data[:-1].mean()
    tail = data[1:] - data[1:].mean()
    scale = max(float(np.linalg.norm(head) * np.linalg.norm(tail)), 1e-8)
    return {"mean": avg, "std": spread, "lag1": float(np.dot(head, tail) / scale)}
def metric_differences(generated: Dict[str, Dict[str, float]], reference: Dict[str, Dict[str, float]]) -> Dict[str, float]:
    """Per-metric mean absolute gap between generated and reference summaries.

    Pairs metrics by feature and metric name, skipping missing or NaN values,
    and averages the absolute gaps across features.
    """
    gaps: Dict[str, List[float]] = {}
    for feature, gen_metrics in generated.items():
        ref_metrics = reference.get(feature, {})
        for metric, gen_value in gen_metrics.items():
            ref_value = ref_metrics.get(metric)
            if ref_value is None:
                continue
            g = float(gen_value)
            r = float(ref_value)
            if math.isnan(g) or math.isnan(r):
                continue
            gaps.setdefault(metric, []).append(abs(g - r))
    return {f"mean_abs_diff_{metric}": safe_mean(diffs) for metric, diffs in gaps.items()}
def summarize_type_metrics(
    feature_names: Sequence[str],
    gen_rows: np.ndarray,
    real_rows: np.ndarray,
    features: Sequence[str],
    stat_fn,
    use_real_bounds: bool = False,
) -> Dict:
    """Apply a per-feature statistic to generated vs. real rows and diff them.

    For each requested feature (silently skipping names absent from
    ``feature_names``), runs ``stat_fn`` over the generated and the real
    column and collects both summaries. When ``use_real_bounds`` is True,
    ``stat_fn`` is additionally passed the real column's min/max (used for
    controller saturation stats); the same real-derived bounds are applied
    to both series so the comparison is like-for-like.

    Returns a dict with the feature list, per-feature "generated" and
    "reference" summaries, and "aggregates" (mean absolute metric gaps via
    metric_differences).
    """
    feature_to_idx = {name: idx for idx, name in enumerate(feature_names)}
    generated: Dict[str, Dict[str, float]] = {}
    reference: Dict[str, Dict[str, float]] = {}
    for feature in features:
        if feature not in feature_to_idx:
            # Config may list features not present in this dataset's columns.
            continue
        idx = feature_to_idx[feature]
        gen_series = gen_rows[:, idx].astype(float).tolist()
        real_series = real_rows[:, idx].astype(float).tolist()
        if use_real_bounds:
            vmin = float(np.min(real_rows[:, idx])) if real_rows.size else 0.0
            vmax = float(np.max(real_rows[:, idx])) if real_rows.size else 0.0
            generated[feature] = stat_fn(gen_series, vmin, vmax)
            reference[feature] = stat_fn(real_series, vmin, vmax)
        else:
            generated[feature] = stat_fn(gen_series)
            reference[feature] = stat_fn(real_series)
    return {
        "features": list(features),
        "generated": generated,
        "reference": reference,
        "aggregates": metric_differences(generated, reference),
    }
def make_loader(x: np.ndarray, y: Optional[np.ndarray], batch_size: int, shuffle: bool) -> DataLoader:
    """Wrap numpy arrays in a float32 DataLoader; targets are optional."""
    inputs = torch.tensor(x, dtype=torch.float32)
    if y is not None:
        targets = torch.tensor(y, dtype=torch.float32)
        dataset = TensorDataset(inputs, targets)
    else:
        dataset = TensorDataset(inputs)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
def split_train_val(
    x: np.ndarray,
    y: np.ndarray,
    seed: int,
    val_ratio: float = 0.2,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Deterministic shuffled train/validation split.

    At least one sample always lands in train; when no samples remain for
    validation, the train slice doubles as validation so callers never see
    an empty eval set.
    """
    rng = np.random.default_rng(seed)
    order = np.arange(x.shape[0])
    rng.shuffle(order)
    cut = max(1, int(round(x.shape[0] * (1.0 - val_ratio))))
    train_part = order[:cut]
    val_part = order[cut:] if cut < order.size else order[:0]
    if val_part.size == 0:
        val_part = train_part
    return x[train_part], y[train_part], x[val_part], y[val_part]
def train_classifier(
    train_x: np.ndarray,
    train_y: np.ndarray,
    val_x: np.ndarray,
    val_y: np.ndarray,
    device: str,
    hidden_dim: int,
    batch_size: int,
    epochs: int,
    seed: int,
) -> Dict[str, float]:
    """Train a binary MLP classifier and score it on the validation split.

    Returns accuracy, balanced accuracy, and AUROC on (val_x, val_y); all
    NaN when training is degenerate (fewer than two samples or one class).
    MLPClassifier and binary_auroc are project-defined helpers.
    """
    # Degenerate training data: nothing meaningful can be learned.
    if train_x.shape[0] < 2 or len(np.unique(train_y)) < 2:
        return {"accuracy": float("nan"), "balanced_accuracy": float("nan"), "auroc": float("nan")}
    torch.manual_seed(seed)
    model = MLPClassifier(train_x.shape[1], hidden_dim=hidden_dim).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.BCEWithLogitsLoss()
    # Labels reshaped to (N, 1) so the loader yields a target tensor per batch.
    loader = make_loader(train_x, train_y.reshape(-1, 1), batch_size=batch_size, shuffle=True)
    model.train()
    for _ in range(epochs):
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device).view(-1)
            logits = model(batch_x)
            loss = loss_fn(logits, batch_y)
            opt.zero_grad()
            loss.backward()
            opt.step()
    model.eval()
    with torch.no_grad():
        logits = model(torch.tensor(val_x, dtype=torch.float32, device=device)).cpu().numpy()
    # Sigmoid in numpy; threshold at 0.5 for hard predictions.
    probs = 1.0 / (1.0 + np.exp(-logits))
    pred = (probs >= 0.5).astype(np.int64)
    y_true = val_y.astype(np.int64)
    accuracy = float((pred == y_true).mean())
    # Confusion-matrix counts for balanced accuracy.
    tp = ((pred == 1) & (y_true == 1)).sum()
    tn = ((pred == 0) & (y_true == 0)).sum()
    fp = ((pred == 1) & (y_true == 0)).sum()
    fn = ((pred == 0) & (y_true == 1)).sum()
    # max(..., 1) guards against division by zero when a class is absent.
    tpr = tp / max(tp + fn, 1)
    tnr = tn / max(tn + fp, 1)
    return {
        "accuracy": accuracy,
        "balanced_accuracy": float(0.5 * (tpr + tnr)),
        "auroc": binary_auroc(y_true, probs),
    }
def train_regressor(
    train_x: np.ndarray,
    train_y: np.ndarray,
    eval_x: np.ndarray,
    eval_y: np.ndarray,
    device: str,
    hidden_dim: int,
    batch_size: int,
    epochs: int,
    seed: int,
) -> Dict[str, float]:
    """Train an MLP regressor and report RMSE/MAE on the eval split.

    Returns NaN metrics when either split is empty. MLPRegressor is a
    project-defined model; targets may be multi-dimensional (train_y is
    expected 2-D — its second axis sizes the output layer).
    """
    if train_x.shape[0] == 0 or eval_x.shape[0] == 0:
        return {"rmse": float("nan"), "mae": float("nan")}
    torch.manual_seed(seed)
    model = MLPRegressor(train_x.shape[1], train_y.shape[1], hidden_dim=hidden_dim).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.MSELoss()
    loader = make_loader(train_x, train_y, batch_size=batch_size, shuffle=True)
    model.train()
    for _ in range(epochs):
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            pred = model(batch_x)
            loss = loss_fn(pred, batch_y)
            opt.zero_grad()
            loss.backward()
            opt.step()
    model.eval()
    with torch.no_grad():
        pred = model(torch.tensor(eval_x, dtype=torch.float32, device=device)).cpu().numpy()
    diff = pred - eval_y
    return {
        "rmse": float(np.sqrt(np.mean(diff ** 2))),
        "mae": float(np.mean(np.abs(diff))),
    }
def train_autoencoder(
    train_x: np.ndarray,
    eval_x: np.ndarray,
    eval_labels: np.ndarray,
    device: str,
    hidden_dim: int,
    batch_size: int,
    epochs: int,
    seed: int,
    threshold_quantile: float,
) -> Dict[str, float]:
    """Train a reconstruction autoencoder and score it as an anomaly detector.

    Trains on (presumed-normal) train_x, scores eval_x by per-sample mean
    squared reconstruction error, and evaluates against binary eval_labels.
    The detection threshold is the ``threshold_quantile`` quantile of the
    *training* reconstruction errors. Returns AUROC/AUPRC plus F1 at that
    threshold and the best achievable F1. MLPAutoencoder and the binary_*
    metric helpers are project-defined.
    """
    if train_x.shape[0] == 0 or eval_x.shape[0] == 0:
        return {
            "auroc": float("nan"),
            "auprc": float("nan"),
            "threshold": float("nan"),
            "f1": float("nan"),
            "best_f1": float("nan"),
        }
    torch.manual_seed(seed)
    # Bottleneck scales with hidden_dim but never below 32 units.
    latent_dim = max(32, hidden_dim // 4)
    model = MLPAutoencoder(train_x.shape[1], hidden_dim=hidden_dim, latent_dim=latent_dim).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.MSELoss()
    # No targets: the loader yields 1-tuples, reconstruction target is the input.
    loader = make_loader(train_x, None, batch_size=batch_size, shuffle=True)
    model.train()
    for _ in range(epochs):
        for (batch_x,) in loader:
            batch_x = batch_x.to(device)
            recon = model(batch_x)
            loss = loss_fn(recon, batch_x)
            opt.zero_grad()
            loss.backward()
            opt.step()

    model.eval()
    with torch.no_grad():
        train_tensor = torch.tensor(train_x, dtype=torch.float32, device=device)
        train_recon = model(train_tensor).cpu().numpy()
        eval_tensor = torch.tensor(eval_x, dtype=torch.float32, device=device)
        eval_recon = model(eval_tensor).cpu().numpy()
    # Anomaly score = mean squared reconstruction error per sample.
    train_scores = np.mean((train_recon - train_x) ** 2, axis=1)
    eval_scores = np.mean((eval_recon - eval_x) ** 2, axis=1)
    threshold = float(np.quantile(train_scores, threshold_quantile))
    f1_stats = binary_f1_at_threshold(eval_labels, eval_scores, threshold)
    best_stats = best_binary_f1(eval_labels, eval_scores)
    return {
        "auroc": binary_auroc(eval_labels, eval_scores),
        "auprc": binary_average_precision(eval_labels, eval_scores),
        "threshold": threshold,
        "f1": f1_stats["f1"],
        "precision": f1_stats["precision"],
        "recall": f1_stats["recall"],
        "best_f1": best_stats["f1"],
        "best_f1_threshold": best_stats["threshold"],
    }
def bootstrap_to_size(array: np.ndarray, target_size: int, seed: int) -> np.ndarray:
    """Resample rows with replacement up to *target_size*.

    A no-op (the original array is returned) when the array is empty or
    already at least target_size rows; otherwise sampling is deterministic
    for a given seed.
    """
    n = array.shape[0]
    if n == 0 or n >= target_size:
        return array
    rng = np.random.default_rng(seed)
    picks = rng.choice(n, size=target_size, replace=True)
    return array[picks]
def build_predictive_pairs(cont_windows: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Turn (N, T, F) windows into next-step regression pairs.

    Inputs are the first T-1 timesteps flattened to (N, (T-1)*F); targets are
    the final timestep (N, F). Empty (0, 0) pairs are returned when there are
    no windows or the windows are too short.
    """
    if cont_windows.shape[0] == 0 or cont_windows.shape[1] < 2:
        return np.zeros((0, 0), dtype=np.float32), np.zeros((0, 0), dtype=np.float32)
    history = cont_windows[:, :-1, :]
    inputs = history.reshape(history.shape[0], -1).astype(np.float32)
    targets = cont_windows[:, -1, :].astype(np.float32)
    return inputs, targets
def main():
    """Run the full evaluation suite for one generated dataset.

    Loads real train/test windows and generated windows, then computes:
    two-sample tests (MMD, discriminative classifier), diversity/privacy
    stats, cross-feature coupling, spectral distance, per-feature-type
    metrics, predictive (train-on-synthetic) consistency, and downstream
    anomaly-detection utility. Writes a single JSON report to --out.
    Most helpers (parse_args, load_windows_from_paths, rbf_mmd, ...) are
    project-defined elsewhere in this module/package.
    """
    args = parse_args()
    set_random_seed(args.seed)
    device = resolve_device(args.device, verbose=False)

    # Optional config supplies seq_len defaults and feature-type lists.
    cfg = {}
    if args.config and Path(args.config).exists():
        cfg = load_json(args.config)
    seq_len = int(args.seq_len or cfg.get("sample_seq_len", cfg.get("seq_len", 96)))
    stride = int(args.stride or seq_len)
    max_rows_per_file = args.max_rows_per_file if args.max_rows_per_file > 0 else None

    _, cont_cols, disc_cols, label_cols = load_split_columns(args.split)
    train_paths = resolve_reference_paths(args.reference)
    test_paths = infer_test_paths(args.reference)
    vocab, vocab_sizes = load_vocab(args.vocab, disc_cols)
    mean_vec, std_vec = load_stats_vectors(args.stats, cont_cols)

    # Window the three datasets; generated data uses non-overlapping windows.
    train_cont, train_disc, _ = load_windows_from_paths(
        train_paths,
        cont_cols,
        disc_cols,
        seq_len=seq_len,
        vocab=vocab,
        label_cols=None,
        stride=stride,
        max_windows=args.max_train_windows,
        max_rows_per_file=max_rows_per_file,
    )
    gen_cont, gen_disc, _ = load_windows_from_paths(
        [args.generated],
        cont_cols,
        disc_cols,
        seq_len=seq_len,
        vocab=vocab,
        label_cols=None,
        stride=seq_len,
        max_windows=args.max_generated_windows,
        max_rows_per_file=max_rows_per_file,
    )
    test_cont, test_disc, test_labels = load_windows_from_paths(
        test_paths,
        cont_cols,
        disc_cols,
        seq_len=seq_len,
        vocab=vocab,
        label_cols=label_cols,
        stride=stride,
        max_windows=args.max_test_windows,
        max_rows_per_file=max_rows_per_file,
    )

    if gen_cont.shape[0] == 0:
        raise SystemExit("generated.csv did not contain enough rows for one evaluation window")

    # No labeled test data: carve a holdout (all labeled normal) from train.
    if test_labels is None or test_labels.size == 0:
        split_at = max(1, int(round(train_cont.shape[0] * 0.8)))
        test_cont = train_cont[split_at:]
        test_disc = train_disc[split_at:]
        test_labels = np.zeros(test_cont.shape[0], dtype=np.int64)
        train_cont = train_cont[:split_at]
        train_disc = train_disc[:split_at]

    # Normal-only test subset for fidelity comparisons (anomalies excluded).
    normal_test_mask = test_labels == 0
    normal_test_cont = test_cont[normal_test_mask] if normal_test_mask.any() else test_cont
    normal_test_disc = test_disc[normal_test_mask] if normal_test_mask.any() else test_disc

    # Two embedding views per split: flat per-window vectors and histograms.
    flat_train = build_flat_window_vectors(train_cont, train_disc, mean_vec, std_vec, vocab_sizes)
    flat_gen = build_flat_window_vectors(gen_cont, gen_disc, mean_vec, std_vec, vocab_sizes)
    flat_test = build_flat_window_vectors(test_cont, test_disc, mean_vec, std_vec, vocab_sizes)
    hist_train = build_histogram_embeddings(train_cont, train_disc, mean_vec, std_vec, vocab_sizes)
    hist_gen = build_histogram_embeddings(gen_cont, gen_disc, mean_vec, std_vec, vocab_sizes)
    hist_test = build_histogram_embeddings(test_cont, test_disc, mean_vec, std_vec, vocab_sizes)
    hist_normal_test = build_histogram_embeddings(normal_test_cont, normal_test_disc, mean_vec, std_vec, vocab_sizes)

    cont_train_flat = standardize_cont_windows(train_cont, mean_vec, std_vec).reshape(train_cont.shape[0], -1)
    cont_gen_flat = standardize_cont_windows(gen_cont, mean_vec, std_vec).reshape(gen_cont.shape[0], -1)

    train_rows = flatten_rows(train_cont)
    gen_rows = flatten_rows(gen_cont)

    # Balance real/generated sample counts for the two-sample tests.
    balanced = min(hist_train.shape[0], hist_gen.shape[0])
    idx_real = sample_indices(hist_train.shape[0], balanced, args.seed)
    idx_gen = sample_indices(hist_gen.shape[0], balanced, args.seed + 1)
    discrim_x = np.concatenate([flat_train[idx_real], flat_gen[idx_gen]], axis=0) if balanced > 0 else np.zeros((0, flat_train.shape[1]), dtype=np.float32)
    discrim_y = np.concatenate(
        [np.zeros(balanced, dtype=np.int64), np.ones(balanced, dtype=np.int64)],
        axis=0,
    ) if balanced > 0 else np.zeros((0,), dtype=np.int64)
    # Discriminative test: a classifier near 50% accuracy means real and
    # generated windows are hard to tell apart.
    if discrim_x.shape[0] > 0:
        x_train, y_train, x_val, y_val = split_train_val(discrim_x, discrim_y, seed=args.seed)
        discriminative = train_classifier(
            x_train,
            y_train,
            x_val,
            y_val,
            device=device,
            hidden_dim=args.hidden_dim,
            batch_size=args.batch_size,
            epochs=args.classifier_epochs,
            seed=args.seed,
        )
    else:
        discriminative = {"accuracy": float("nan"), "balanced_accuracy": float("nan"), "auroc": float("nan")}

    mmd_cont, gamma_cont = rbf_mmd(cont_train_flat[idx_real] if balanced > 0 else cont_train_flat, cont_gen_flat[idx_gen] if balanced > 0 else cont_gen_flat)
    mmd_hist, gamma_hist = rbf_mmd(hist_train[idx_real] if balanced > 0 else hist_train, hist_gen[idx_gen] if balanced > 0 else hist_gen)

    # Diversity / memorization: compare generated-to-train NN distances
    # against holdout-to-train distances (ratio << 1 suggests copying).
    holdout_base = hist_normal_test if hist_normal_test.shape[0] > 0 else hist_test
    nn_gen = nearest_neighbor_distance_stats(hist_gen, hist_train)
    nn_holdout = nearest_neighbor_distance_stats(holdout_base, hist_train)
    diversity = {
        "duplicate_rate": duplicate_rate(flat_gen),
        "exact_match_rate_to_train": exact_match_rate(flat_gen, flat_train),
        "nn_gen_to_train_mean": nn_gen["mean"],
        "nn_holdout_to_train_mean": nn_holdout["mean"],
        "memorization_ratio": float(nn_gen["mean"] / max(nn_holdout["mean"], 1e-8)) if not math.isnan(nn_gen["mean"]) and not math.isnan(nn_holdout["mean"]) else float("nan"),
        "one_nn_two_sample_accuracy": one_nn_two_sample_accuracy(holdout_base, hist_gen),
    }

    # Cross-feature coupling: instantaneous and lag-1 correlation structure,
    # overall and per process group.
    corr_real = compute_corr_matrix(train_rows)
    corr_gen = compute_corr_matrix(gen_rows)
    lag_corr_real = lagged_corr_from_windows(train_cont)
    lag_corr_gen = lagged_corr_from_windows(gen_cont)
    coupling = {
        "corr_mean_abs_diff": mean_abs_matrix_diff(corr_real, corr_gen),
        "corr_frobenius": fro_matrix_diff(corr_real, corr_gen),
        "lag1_corr_mean_abs_diff": mean_abs_matrix_diff(lag_corr_real, lag_corr_gen),
        "lag1_corr_frobenius": fro_matrix_diff(lag_corr_real, lag_corr_gen),
        "by_process": {},
    }
    process_groups = split_process_groups(cont_cols)
    for process_name, indices in process_groups.items():
        real_block = corr_real[np.ix_(indices, indices)]
        gen_block = corr_gen[np.ix_(indices, indices)]
        real_lag_block = lag_corr_real[np.ix_(indices, indices)]
        gen_lag_block = lag_corr_gen[np.ix_(indices, indices)]
        coupling["by_process"][process_name] = {
            "corr_mean_abs_diff": mean_abs_matrix_diff(real_block, gen_block),
            "corr_frobenius": fro_matrix_diff(real_block, gen_block),
            "lag1_corr_mean_abs_diff": mean_abs_matrix_diff(real_lag_block, gen_lag_block),
            "lag1_corr_frobenius": fro_matrix_diff(real_lag_block, gen_lag_block),
        }

    frequency = psd_distance_stats(compute_average_psd(train_cont), compute_average_psd(gen_cont))

    # Per-feature-type metrics, driven by feature lists from the config.
    cfg_types = {
        "type1": list(cfg.get("type1_features", []) or []),
        "type2": list(cfg.get("type2_features", []) or []),
        "type3": list(cfg.get("type3_features", []) or []),
        "type4": list(cfg.get("type4_features", []) or []),
        "type5": list(cfg.get("type5_features", []) or []),
        "type6": list(cfg.get("type6_features", []) or []),
    }
    type_metrics = {
        "type1_program": summarize_type_metrics(cont_cols, gen_rows, train_rows, cfg_types["type1"], dwell_and_steps),
        "type2_controller": summarize_type_metrics(cont_cols, gen_rows, train_rows, cfg_types["type2"], controller_stats, use_real_bounds=True),
        "type3_actuator": summarize_type_metrics(cont_cols, gen_rows, train_rows, cfg_types["type3"], actuator_stats),
        "type4_pv": summarize_type_metrics(cont_cols, gen_rows, train_rows, cfg_types["type4"], pv_stats),
        "type5_program_proxy": summarize_type_metrics(cont_cols, gen_rows, train_rows, cfg_types["type5"], dwell_and_steps),
        "type6_aux": summarize_type_metrics(cont_cols, gen_rows, train_rows, cfg_types["type6"], aux_stats),
    }

    # Predictive consistency (TSTR-style): next-step regressors trained on
    # real, synthetic, and combined data, all evaluated on normal test data.
    pred_train_x, pred_train_y = build_predictive_pairs(standardize_cont_windows(train_cont, mean_vec, std_vec))
    pred_gen_x, pred_gen_y = build_predictive_pairs(standardize_cont_windows(gen_cont, mean_vec, std_vec))
    pred_eval_x, pred_eval_y = build_predictive_pairs(standardize_cont_windows(normal_test_cont, mean_vec, std_vec))
    predictive = {
        "real_only": train_regressor(
            pred_train_x,
            pred_train_y,
            pred_eval_x,
            pred_eval_y,
            device=device,
            hidden_dim=args.hidden_dim,
            batch_size=args.batch_size,
            epochs=args.predictor_epochs,
            seed=args.seed,
        ),
        "synthetic_only": train_regressor(
            pred_gen_x,
            pred_gen_y,
            pred_eval_x,
            pred_eval_y,
            device=device,
            hidden_dim=args.hidden_dim,
            batch_size=args.batch_size,
            epochs=args.predictor_epochs,
            seed=args.seed + 1,
        ),
        "real_plus_synthetic": train_regressor(
            np.concatenate([pred_train_x, bootstrap_to_size(pred_gen_x, pred_train_x.shape[0], args.seed + 2)], axis=0) if pred_train_x.size and pred_gen_x.size else pred_train_x,
            np.concatenate([pred_train_y, bootstrap_to_size(pred_gen_y, pred_train_y.shape[0], args.seed + 2)], axis=0) if pred_train_y.size and pred_gen_y.size else pred_train_y,
            pred_eval_x,
            pred_eval_y,
            device=device,
            hidden_dim=args.hidden_dim,
            batch_size=args.batch_size,
            epochs=args.predictor_epochs,
            seed=args.seed + 2,
        ),
    }
    predictive["rmse_ratio_synth_to_real"] = (
        float(predictive["synthetic_only"]["rmse"] / max(predictive["real_only"]["rmse"], 1e-8))
        if not math.isnan(predictive["synthetic_only"]["rmse"]) and not math.isnan(predictive["real_only"]["rmse"])
        else float("nan")
    )

    # Downstream anomaly-detection utility: autoencoder detectors trained on
    # real, synthetic (bootstrapped to a comparable size), and combined data.
    target_size = max(flat_train.shape[0], min(512, flat_test.shape[0])) if flat_train.shape[0] > 0 else flat_gen.shape[0]
    utility = {
        "real_only": train_autoencoder(
            flat_train,
            flat_test,
            test_labels.astype(np.int64),
            device=device,
            hidden_dim=args.hidden_dim,
            batch_size=args.batch_size,
            epochs=args.detector_epochs,
            seed=args.seed,
            threshold_quantile=args.detector_threshold_quantile,
        ),
        "synthetic_only": train_autoencoder(
            bootstrap_to_size(flat_gen, target_size, args.seed + 3),
            flat_test,
            test_labels.astype(np.int64),
            device=device,
            hidden_dim=args.hidden_dim,
            batch_size=args.batch_size,
            epochs=args.detector_epochs,
            seed=args.seed + 3,
            threshold_quantile=args.detector_threshold_quantile,
        ),
        "real_plus_synthetic": train_autoencoder(
            np.concatenate([flat_train, bootstrap_to_size(flat_gen, flat_train.shape[0], args.seed + 4)], axis=0) if flat_train.size and flat_gen.size else flat_train,
            flat_test,
            test_labels.astype(np.int64),
            device=device,
            hidden_dim=args.hidden_dim,
            batch_size=args.batch_size,
            epochs=args.detector_epochs,
            seed=args.seed + 4,
            threshold_quantile=args.detector_threshold_quantile,
        ),
    }

    # Pull in the basic eval report next to the generated CSV if one exists.
    eval_name = "eval_post.json" if Path(args.generated).name.startswith("generated_post") else "eval.json"
    eval_candidate = Path(args.generated).with_name(eval_name)
    basic_eval = load_json(eval_candidate) if eval_candidate.exists() else {}

    out = {
        "generated_path": str(Path(args.generated).resolve()),
        "reference_paths": train_paths,
        "test_paths": test_paths,
        "seq_len": seq_len,
        "stride": stride,
        "counts": {
            "train_windows": int(train_cont.shape[0]),
            "generated_windows": int(gen_cont.shape[0]),
            "test_windows": int(test_cont.shape[0]),
            "test_anomalous_windows": int(test_labels.sum()),
            "test_normal_windows": int(normal_test_cont.shape[0]),
        },
        "basic_eval": {
            "avg_ks": basic_eval.get("avg_ks"),
            "avg_jsd": basic_eval.get("avg_jsd"),
            "avg_lag1_diff": basic_eval.get("avg_lag1_diff"),
        },
        "two_sample": {
            "continuous_mmd_rbf": mmd_cont,
            "continuous_mmd_gamma": gamma_cont,
            "histogram_mmd_rbf": mmd_hist,
            "histogram_mmd_gamma": gamma_hist,
            "discriminative_accuracy": discriminative["accuracy"],
            "discriminative_balanced_accuracy": discriminative["balanced_accuracy"],
            "discriminative_auroc": discriminative["auroc"],
        },
        "diversity_privacy": diversity,
        "coupling": coupling,
        "frequency": frequency,
        "type_metrics": type_metrics,
        "predictive_consistency": predictive,
        "anomaly_utility": utility,
    }

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(out, indent=2), encoding="utf-8")
    print("wrote", out_path)
    print("generated_windows", gen_cont.shape[0])
    print("continuous_mmd_rbf", mmd_cont)
    print("discriminative_accuracy", discriminative["accuracy"])
    print("utility_real_plus_synthetic_auprc", utility["real_plus_synthetic"]["auprc"])
# Script entry point: run the evaluation suite when invoked directly.
if __name__ == "__main__":
    main()
@@ -1,44 +1,66 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""Prepare vocab and normalization stats for HAI 21.03."""
|
"""Prepare vocab and normalization stats for HAI-style CSV datasets."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from data_utils import compute_cont_stats, build_disc_stats, load_split, choose_cont_transforms
|
from data_utils import compute_cont_stats, build_disc_stats, load_split, choose_cont_transforms
|
||||||
from platform_utils import safe_path, ensure_dir
|
from platform_utils import safe_path, ensure_dir, resolve_path
|
||||||
|
|
||||||
BASE_DIR = Path(__file__).resolve().parent
|
BASE_DIR = Path(__file__).resolve().parent
|
||||||
REPO_DIR = BASE_DIR.parent.parent
|
REPO_DIR = BASE_DIR.parent.parent
|
||||||
DATA_GLOB = REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz"
|
|
||||||
SPLIT_PATH = BASE_DIR / "feature_split.json"
|
|
||||||
OUT_STATS = BASE_DIR / "results" / "cont_stats.json"
|
|
||||||
OUT_VOCAB = BASE_DIR / "results" / "disc_vocab.json"
|
|
||||||
|
|
||||||
|
|
||||||
def main(max_rows: Optional[int] = None):
|
def parse_args():
|
||||||
config_path = BASE_DIR / "config.json"
|
parser = argparse.ArgumentParser(description="Prepare vocab and normalization stats.")
|
||||||
use_quantile = False
|
parser.add_argument("--config", default=str(BASE_DIR / "config.json"), help="Path to JSON config")
|
||||||
quantile_bins = None
|
parser.add_argument("--max-rows", type=int, default=50000, help="Sample cap for stats; ignored when full_stats=true")
|
||||||
full_stats = False
|
return parser.parse_args()
|
||||||
if config_path.exists():
|
|
||||||
|
|
||||||
|
def resolve_data_paths(cfg: dict, cfg_path: Path) -> list[str]:
|
||||||
|
base_dir = cfg_path.parent
|
||||||
|
data_glob = cfg.get("data_glob", "")
|
||||||
|
data_path = cfg.get("data_path", "")
|
||||||
|
paths = []
|
||||||
|
if data_glob:
|
||||||
|
resolved_glob = resolve_path(base_dir, data_glob)
|
||||||
|
paths = sorted(Path(resolved_glob).parent.glob(Path(resolved_glob).name))
|
||||||
|
elif data_path:
|
||||||
|
resolved_path = resolve_path(base_dir, data_path)
|
||||||
|
if Path(resolved_path).exists():
|
||||||
|
paths = [Path(resolved_path)]
|
||||||
|
return [safe_path(p) for p in paths]
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
config_path = Path(args.config)
|
||||||
|
if not config_path.is_absolute():
|
||||||
|
config_path = resolve_path(BASE_DIR, config_path)
|
||||||
|
if not config_path.exists():
|
||||||
|
raise SystemExit(f"missing config: {config_path}")
|
||||||
|
|
||||||
cfg = json.loads(config_path.read_text(encoding="utf-8"))
|
cfg = json.loads(config_path.read_text(encoding="utf-8"))
|
||||||
use_quantile = bool(cfg.get("use_quantile_transform", False))
|
use_quantile = bool(cfg.get("use_quantile_transform", False))
|
||||||
quantile_bins = int(cfg.get("quantile_bins", 0)) if use_quantile else None
|
quantile_bins = int(cfg.get("quantile_bins", 0)) if use_quantile else None
|
||||||
full_stats = bool(cfg.get("full_stats", False))
|
full_stats = bool(cfg.get("full_stats", False))
|
||||||
|
max_rows: Optional[int] = args.max_rows
|
||||||
|
|
||||||
if full_stats:
|
if full_stats:
|
||||||
max_rows = None
|
max_rows = None
|
||||||
|
|
||||||
split = load_split(safe_path(SPLIT_PATH))
|
split_path = resolve_path(config_path.parent, cfg.get("split_path", "./feature_split.json"))
|
||||||
|
split = load_split(safe_path(split_path))
|
||||||
time_col = split.get("time_column", "time")
|
time_col = split.get("time_column", "time")
|
||||||
cont_cols = [c for c in split["continuous"] if c != time_col]
|
cont_cols = [c for c in split["continuous"] if c != time_col]
|
||||||
disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
|
disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
|
||||||
|
|
||||||
data_paths = sorted(Path(REPO_DIR / "dataset" / "hai" / "hai-21.03").glob("train*.csv.gz"))
|
data_paths = resolve_data_paths(cfg, config_path)
|
||||||
if not data_paths:
|
if not data_paths:
|
||||||
raise SystemExit("no train files found under %s" % str(DATA_GLOB))
|
raise SystemExit(f"no train files found for config: {config_path}")
|
||||||
data_paths = [safe_path(p) for p in data_paths]
|
|
||||||
|
|
||||||
transforms, _ = choose_cont_transforms(data_paths, cont_cols, max_rows=max_rows)
|
transforms, _ = choose_cont_transforms(data_paths, cont_cols, max_rows=max_rows)
|
||||||
cont_stats = compute_cont_stats(
|
cont_stats = compute_cont_stats(
|
||||||
@@ -50,8 +72,12 @@ def main(max_rows: Optional[int] = None):
|
|||||||
)
|
)
|
||||||
vocab, top_token = build_disc_stats(data_paths, disc_cols, max_rows=max_rows)
|
vocab, top_token = build_disc_stats(data_paths, disc_cols, max_rows=max_rows)
|
||||||
|
|
||||||
ensure_dir(OUT_STATS.parent)
|
out_stats = resolve_path(config_path.parent, cfg.get("stats_path", "./results/cont_stats.json"))
|
||||||
with open(safe_path(OUT_STATS), "w", encoding="utf-8") as f:
|
out_vocab = resolve_path(config_path.parent, cfg.get("vocab_path", "./results/disc_vocab.json"))
|
||||||
|
ensure_dir(out_stats.parent)
|
||||||
|
ensure_dir(out_vocab.parent)
|
||||||
|
|
||||||
|
with open(safe_path(out_stats), "w", encoding="utf-8") as f:
|
||||||
json.dump(
|
json.dump(
|
||||||
{
|
{
|
||||||
"mean": cont_stats["mean"],
|
"mean": cont_stats["mean"],
|
||||||
@@ -73,10 +99,9 @@ def main(max_rows: Optional[int] = None):
|
|||||||
indent=2,
|
indent=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
with open(safe_path(OUT_VOCAB), "w", encoding="utf-8") as f:
|
with open(safe_path(out_vocab), "w", encoding="utf-8") as f:
|
||||||
json.dump({"vocab": vocab, "top_token": top_token, "max_rows": max_rows}, f, indent=2)
|
json.dump({"vocab": vocab, "top_token": top_token, "max_rows": max_rows}, f, indent=2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Default: sample 50000 rows for speed. Set to None for full scan.
|
main()
|
||||||
main(max_rows=50000)
|
|
||||||
|
|||||||
224
example/run_ablations.py
Normal file
224
example/run_ablations.py
Normal file
@@ -0,0 +1,224 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Run a default ablation suite and summarize results."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from platform_utils import safe_path, is_windows
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_ABLATIONS = {
|
||||||
|
"full": {},
|
||||||
|
"no_temporal": {
|
||||||
|
"use_temporal_stage1": False,
|
||||||
|
},
|
||||||
|
"no_quantile": {
|
||||||
|
"use_quantile_transform": False,
|
||||||
|
"cont_post_calibrate": False,
|
||||||
|
"full_stats": False,
|
||||||
|
},
|
||||||
|
"no_post_calibration": {
|
||||||
|
"cont_post_calibrate": False,
|
||||||
|
},
|
||||||
|
"no_file_condition": {
|
||||||
|
"use_condition": False,
|
||||||
|
},
|
||||||
|
"no_type_routing": {
|
||||||
|
"type1_features": [],
|
||||||
|
"type2_features": [],
|
||||||
|
"type3_features": [],
|
||||||
|
"type4_features": [],
|
||||||
|
"type5_features": [],
|
||||||
|
"type6_features": [],
|
||||||
|
},
|
||||||
|
"no_snr_weight": {
|
||||||
|
"snr_weighted_loss": False,
|
||||||
|
},
|
||||||
|
"no_quantile_loss": {
|
||||||
|
"quantile_loss_weight": 0.0,
|
||||||
|
},
|
||||||
|
"no_residual_stat": {
|
||||||
|
"residual_stat_weight": 0.0,
|
||||||
|
},
|
||||||
|
"eps_target": {
|
||||||
|
"cont_target": "eps",
|
||||||
|
"cont_clamp_x0": 0.0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
base_dir = Path(__file__).resolve().parent
|
||||||
|
parser = argparse.ArgumentParser(description="Run ablation experiments.")
|
||||||
|
parser.add_argument("--config", default=str(base_dir / "config.json"))
|
||||||
|
parser.add_argument("--device", default="auto")
|
||||||
|
parser.add_argument("--variants", default="", help="comma-separated variant names; empty uses defaults")
|
||||||
|
parser.add_argument("--seeds", default="", help="comma-separated seeds; empty uses config seed")
|
||||||
|
parser.add_argument("--out-root", default=str(base_dir / "results" / "ablations"))
|
||||||
|
parser.add_argument("--skip-prepare", action="store_true")
|
||||||
|
parser.add_argument("--skip-train", action="store_true")
|
||||||
|
parser.add_argument("--skip-export", action="store_true")
|
||||||
|
parser.add_argument("--skip-eval", action="store_true")
|
||||||
|
parser.add_argument("--skip-comprehensive-eval", action="store_true")
|
||||||
|
parser.add_argument("--skip-postprocess", action="store_true")
|
||||||
|
parser.add_argument("--skip-post-eval", action="store_true")
|
||||||
|
parser.add_argument("--skip-diagnostics", action="store_true")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def run(cmd: List[str]) -> None:
|
||||||
|
print("running:", " ".join(cmd))
|
||||||
|
cmd = [safe_path(arg) for arg in cmd]
|
||||||
|
if is_windows():
|
||||||
|
subprocess.run(cmd, check=True, shell=False)
|
||||||
|
else:
|
||||||
|
subprocess.run(cmd, check=True)
|
||||||
|
|
||||||
|
|
||||||
|
def load_json(path: Path) -> Dict:
|
||||||
|
with path.open("r", encoding="utf-8") as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
def write_json(path: Path, obj: Dict) -> None:
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with path.open("w", encoding="utf-8") as f:
|
||||||
|
json.dump(obj, f, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
def selected_variants(arg: str) -> List[str]:
|
||||||
|
if not arg:
|
||||||
|
return list(DEFAULT_ABLATIONS.keys())
|
||||||
|
names = [name.strip() for name in arg.split(",") if name.strip()]
|
||||||
|
unknown = [name for name in names if name not in DEFAULT_ABLATIONS]
|
||||||
|
if unknown:
|
||||||
|
raise SystemExit(f"unknown ablation names: {', '.join(unknown)}")
|
||||||
|
return names
|
||||||
|
|
||||||
|
|
||||||
|
def parse_seeds(arg: str, cfg: Dict) -> List[int]:
|
||||||
|
if not arg:
|
||||||
|
return [int(cfg.get("seed", 1337))]
|
||||||
|
return [int(item.strip()) for item in arg.split(",") if item.strip()]
|
||||||
|
|
||||||
|
|
||||||
|
def collect_metrics(run_dir: Path) -> Dict[str, float]:
|
||||||
|
out: Dict[str, float] = {}
|
||||||
|
eval_path = run_dir / "eval.json"
|
||||||
|
if eval_path.exists():
|
||||||
|
data = load_json(eval_path)
|
||||||
|
out["avg_ks"] = data.get("avg_ks")
|
||||||
|
out["avg_jsd"] = data.get("avg_jsd")
|
||||||
|
out["avg_lag1_diff"] = data.get("avg_lag1_diff")
|
||||||
|
comp_path = run_dir / "comprehensive_eval.json"
|
||||||
|
if comp_path.exists():
|
||||||
|
data = load_json(comp_path)
|
||||||
|
out["continuous_mmd_rbf"] = data.get("two_sample", {}).get("continuous_mmd_rbf")
|
||||||
|
out["discriminative_accuracy"] = data.get("two_sample", {}).get("discriminative_accuracy")
|
||||||
|
out["corr_mean_abs_diff"] = data.get("coupling", {}).get("corr_mean_abs_diff")
|
||||||
|
out["avg_psd_l1"] = data.get("frequency", {}).get("avg_psd_l1")
|
||||||
|
out["memorization_ratio"] = data.get("diversity_privacy", {}).get("memorization_ratio")
|
||||||
|
out["predictive_rmse_real"] = data.get("predictive_consistency", {}).get("real_only", {}).get("rmse")
|
||||||
|
out["predictive_rmse_synth"] = data.get("predictive_consistency", {}).get("synthetic_only", {}).get("rmse")
|
||||||
|
out["utility_auprc_real"] = data.get("anomaly_utility", {}).get("real_only", {}).get("auprc")
|
||||||
|
out["utility_auprc_synth"] = data.get("anomaly_utility", {}).get("synthetic_only", {}).get("auprc")
|
||||||
|
out["utility_auprc_aug"] = data.get("anomaly_utility", {}).get("real_plus_synthetic", {}).get("auprc")
|
||||||
|
post_eval_path = run_dir / "eval_post.json"
|
||||||
|
if post_eval_path.exists():
|
||||||
|
post = load_json(post_eval_path)
|
||||||
|
out["post_avg_ks"] = post.get("avg_ks")
|
||||||
|
out["post_avg_jsd"] = post.get("avg_jsd")
|
||||||
|
out["post_avg_lag1_diff"] = post.get("avg_lag1_diff")
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
base_dir = Path(__file__).resolve().parent
|
||||||
|
out_root = Path(args.out_root)
|
||||||
|
out_root.mkdir(parents=True, exist_ok=True)
|
||||||
|
config_path = Path(args.config)
|
||||||
|
if not config_path.is_absolute():
|
||||||
|
config_path = (base_dir / config_path).resolve()
|
||||||
|
base_cfg = load_json(config_path)
|
||||||
|
variants = selected_variants(args.variants)
|
||||||
|
seeds = parse_seeds(args.seeds, base_cfg)
|
||||||
|
|
||||||
|
generated_configs: Dict[str, Path] = {}
|
||||||
|
for variant in variants:
|
||||||
|
cfg = dict(base_cfg)
|
||||||
|
cfg.update(DEFAULT_ABLATIONS[variant])
|
||||||
|
cfg_path = out_root / "configs" / f"{variant}.json"
|
||||||
|
write_json(cfg_path, cfg)
|
||||||
|
generated_configs[variant] = cfg_path
|
||||||
|
|
||||||
|
history_path = out_root / "benchmark_history.csv"
|
||||||
|
summary_path = out_root / "benchmark_summary.csv"
|
||||||
|
runs_root = out_root / "runs"
|
||||||
|
rows: List[Dict[str, object]] = []
|
||||||
|
|
||||||
|
for variant in variants:
|
||||||
|
cfg_path = generated_configs[variant]
|
||||||
|
cmd = [
|
||||||
|
sys.executable,
|
||||||
|
str(base_dir / "run_all.py"),
|
||||||
|
"--config",
|
||||||
|
str(cfg_path),
|
||||||
|
"--device",
|
||||||
|
args.device,
|
||||||
|
"--runs-root",
|
||||||
|
str(runs_root),
|
||||||
|
"--benchmark-history",
|
||||||
|
str(history_path),
|
||||||
|
"--benchmark-summary",
|
||||||
|
str(summary_path),
|
||||||
|
"--seeds",
|
||||||
|
",".join(str(seed) for seed in seeds),
|
||||||
|
]
|
||||||
|
if args.skip_train:
|
||||||
|
cmd.append("--skip-train")
|
||||||
|
if args.skip_prepare:
|
||||||
|
cmd.append("--skip-prepare")
|
||||||
|
if args.skip_export:
|
||||||
|
cmd.append("--skip-export")
|
||||||
|
if args.skip_eval:
|
||||||
|
cmd.append("--skip-eval")
|
||||||
|
if args.skip_comprehensive_eval:
|
||||||
|
cmd.append("--skip-comprehensive-eval")
|
||||||
|
if args.skip_postprocess:
|
||||||
|
cmd.append("--skip-postprocess")
|
||||||
|
if args.skip_post_eval:
|
||||||
|
cmd.append("--skip-post-eval")
|
||||||
|
if args.skip_diagnostics:
|
||||||
|
cmd.append("--skip-diagnostics")
|
||||||
|
run(cmd)
|
||||||
|
|
||||||
|
for seed in seeds:
|
||||||
|
run_dir = runs_root / f"{cfg_path.stem}__seed{seed}"
|
||||||
|
row: Dict[str, object] = {"variant": variant, "seed": seed, "run_dir": str(run_dir)}
|
||||||
|
row.update(collect_metrics(run_dir))
|
||||||
|
rows.append(row)
|
||||||
|
|
||||||
|
fieldnames = sorted({key for row in rows for key in row.keys()})
|
||||||
|
csv_path = out_root / "ablation_summary.csv"
|
||||||
|
with csv_path.open("w", encoding="utf-8", newline="") as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||||
|
writer.writeheader()
|
||||||
|
for row in rows:
|
||||||
|
writer.writerow(row)
|
||||||
|
|
||||||
|
json_path = out_root / "ablation_summary.json"
|
||||||
|
write_json(json_path, {"variants": variants, "seeds": seeds, "rows": rows})
|
||||||
|
print("wrote", csv_path)
|
||||||
|
print("wrote", json_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -38,6 +38,7 @@ def parse_args():
|
|||||||
parser.add_argument("--skip-train", action="store_true")
|
parser.add_argument("--skip-train", action="store_true")
|
||||||
parser.add_argument("--skip-export", action="store_true")
|
parser.add_argument("--skip-export", action="store_true")
|
||||||
parser.add_argument("--skip-eval", action="store_true")
|
parser.add_argument("--skip-eval", action="store_true")
|
||||||
|
parser.add_argument("--skip-comprehensive-eval", action="store_true")
|
||||||
parser.add_argument("--skip-postprocess", action="store_true")
|
parser.add_argument("--skip-postprocess", action="store_true")
|
||||||
parser.add_argument("--skip-post-eval", action="store_true")
|
parser.add_argument("--skip-post-eval", action="store_true")
|
||||||
parser.add_argument("--skip-diagnostics", action="store_true")
|
parser.add_argument("--skip-diagnostics", action="store_true")
|
||||||
@@ -212,14 +213,13 @@ def main():
|
|||||||
config_paths = expand_config_args(base_dir, args.configs) if args.configs else [resolve_config_path(base_dir, args.config)]
|
config_paths = expand_config_args(base_dir, args.configs) if args.configs else [resolve_config_path(base_dir, args.config)]
|
||||||
batch_mode = bool(args.configs or args.seeds or (args.repeat and args.repeat > 1) or args.runs_root or args.benchmark_history or args.benchmark_summary)
|
batch_mode = bool(args.configs or args.seeds or (args.repeat and args.repeat > 1) or args.runs_root or args.benchmark_history or args.benchmark_summary)
|
||||||
|
|
||||||
if not args.skip_prepare:
|
|
||||||
run([sys.executable, str(base_dir / "prepare_data.py")])
|
|
||||||
|
|
||||||
runs_root = Path(args.runs_root) if args.runs_root else (base_dir / "results" / "runs")
|
runs_root = Path(args.runs_root) if args.runs_root else (base_dir / "results" / "runs")
|
||||||
history_out = Path(args.benchmark_history) if args.benchmark_history else (base_dir / "results" / "benchmark_history.csv")
|
history_out = Path(args.benchmark_history) if args.benchmark_history else (base_dir / "results" / "benchmark_history.csv")
|
||||||
summary_out = Path(args.benchmark_summary) if args.benchmark_summary else (base_dir / "results" / "benchmark_summary.csv")
|
summary_out = Path(args.benchmark_summary) if args.benchmark_summary else (base_dir / "results" / "benchmark_summary.csv")
|
||||||
|
|
||||||
for config_path in config_paths:
|
for config_path in config_paths:
|
||||||
|
if not args.skip_prepare:
|
||||||
|
run([sys.executable, str(base_dir / "prepare_data.py"), "--config", str(config_path)])
|
||||||
cfg_base = config_path.parent
|
cfg_base = config_path.parent
|
||||||
with open(config_path, "r", encoding="utf-8") as f:
|
with open(config_path, "r", encoding="utf-8") as f:
|
||||||
cfg = json.load(f)
|
cfg = json.load(f)
|
||||||
@@ -351,6 +351,29 @@ def main():
|
|||||||
str(base_dir / "results" / "metrics_history.csv"),
|
str(base_dir / "results" / "metrics_history.csv"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
if not args.skip_comprehensive_eval:
|
||||||
|
run(
|
||||||
|
[
|
||||||
|
sys.executable,
|
||||||
|
str(base_dir / "evaluate_comprehensive.py"),
|
||||||
|
"--generated",
|
||||||
|
str(run_dir / "generated.csv"),
|
||||||
|
"--reference",
|
||||||
|
str(config_path),
|
||||||
|
"--config",
|
||||||
|
str(cfg_for_steps),
|
||||||
|
"--split",
|
||||||
|
str(split_path),
|
||||||
|
"--stats",
|
||||||
|
str(stats_path),
|
||||||
|
"--vocab",
|
||||||
|
str(vocab_path),
|
||||||
|
"--out",
|
||||||
|
str(run_dir / "comprehensive_eval.json"),
|
||||||
|
"--device",
|
||||||
|
args.device,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
if not args.skip_postprocess:
|
if not args.skip_postprocess:
|
||||||
cmd = [
|
cmd = [
|
||||||
@@ -387,6 +410,29 @@ def main():
|
|||||||
if ref:
|
if ref:
|
||||||
cmd += ["--reference", str(ref)]
|
cmd += ["--reference", str(ref)]
|
||||||
run(cmd)
|
run(cmd)
|
||||||
|
if not args.skip_comprehensive_eval:
|
||||||
|
run(
|
||||||
|
[
|
||||||
|
sys.executable,
|
||||||
|
str(base_dir / "evaluate_comprehensive.py"),
|
||||||
|
"--generated",
|
||||||
|
str(run_dir / "generated_post.csv"),
|
||||||
|
"--reference",
|
||||||
|
str(config_path),
|
||||||
|
"--config",
|
||||||
|
str(cfg_for_steps),
|
||||||
|
"--split",
|
||||||
|
str(split_path),
|
||||||
|
"--stats",
|
||||||
|
str(stats_path),
|
||||||
|
"--vocab",
|
||||||
|
str(vocab_path),
|
||||||
|
"--out",
|
||||||
|
str(run_dir / "comprehensive_eval_post.json"),
|
||||||
|
"--device",
|
||||||
|
args.device,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
if not args.skip_diagnostics:
|
if not args.skip_diagnostics:
|
||||||
if ref:
|
if ref:
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ def main():
|
|||||||
clip_k = cfg.get("clip_k", 5.0)
|
clip_k = cfg.get("clip_k", 5.0)
|
||||||
|
|
||||||
if not args.skip_prepare:
|
if not args.skip_prepare:
|
||||||
run([sys.executable, str(base_dir / "prepare_data.py")])
|
run([sys.executable, str(base_dir / "prepare_data.py"), "--config", str(config_path)])
|
||||||
if not args.skip_train:
|
if not args.skip_train:
|
||||||
run([sys.executable, str(base_dir / "train.py"), "--config", str(config_path), "--device", args.device])
|
run([sys.executable, str(base_dir / "train.py"), "--config", str(config_path), "--device", args.device])
|
||||||
if not args.skip_export:
|
if not args.skip_export:
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ def main():
|
|||||||
clip_k = cfg.get("clip_k", 5.0)
|
clip_k = cfg.get("clip_k", 5.0)
|
||||||
data_glob = cfg.get("data_glob", "")
|
data_glob = cfg.get("data_glob", "")
|
||||||
data_path = cfg.get("data_path", "")
|
data_path = cfg.get("data_path", "")
|
||||||
run([sys.executable, str(base_dir / "prepare_data.py")])
|
run([sys.executable, str(base_dir / "prepare_data.py"), "--config", str(config_path)])
|
||||||
run([sys.executable, str(base_dir / "train.py"), "--config", args.config, "--device", args.device])
|
run([sys.executable, str(base_dir / "train.py"), "--config", args.config, "--device", args.device])
|
||||||
run(
|
run(
|
||||||
[
|
[
|
||||||
|
|||||||
65
example/window_models.py
Normal file
65
example/window_models.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Small neural models used by the evaluation suite."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class MLPClassifier(nn.Module):
|
||||||
|
def __init__(self, input_dim: int, hidden_dim: int = 256, dropout: float = 0.1):
|
||||||
|
super().__init__()
|
||||||
|
mid_dim = max(hidden_dim // 2, 32)
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
nn.Linear(input_dim, hidden_dim),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(hidden_dim, mid_dim),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(mid_dim, 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.net(x).squeeze(-1)
|
||||||
|
|
||||||
|
|
||||||
|
class MLPRegressor(nn.Module):
|
||||||
|
def __init__(self, input_dim: int, output_dim: int, hidden_dim: int = 256, dropout: float = 0.1):
|
||||||
|
super().__init__()
|
||||||
|
mid_dim = max(hidden_dim // 2, 32)
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
nn.Linear(input_dim, hidden_dim),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(hidden_dim, mid_dim),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(mid_dim, output_dim),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
|
class MLPAutoencoder(nn.Module):
|
||||||
|
def __init__(self, input_dim: int, hidden_dim: int = 256, latent_dim: int = 64, dropout: float = 0.1):
|
||||||
|
super().__init__()
|
||||||
|
self.encoder = nn.Sequential(
|
||||||
|
nn.Linear(input_dim, hidden_dim),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(hidden_dim, latent_dim),
|
||||||
|
nn.GELU(),
|
||||||
|
)
|
||||||
|
self.decoder = nn.Sequential(
|
||||||
|
nn.Linear(latent_dim, hidden_dim),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(hidden_dim, input_dim),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.decoder(self.encoder(x))
|
||||||
|
|
||||||
Reference in New Issue
Block a user