Continuous features fall short on temporal correlation
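Why a per-column distribution check is not enough: a generated series can match every marginal histogram and still carry no temporal structure. The minimal sketch below (toy data and the driver code are ours, not from this repo) mirrors the `lag1_corr` helper added in this commit: shuffling a series preserves its marginal distribution, so a KS-style test sees nothing, while the lag-1 autocorrelation collapses.

import math
import random

def lag1_corr(values):
    # Pearson correlation between the series and itself shifted by one step.
    if len(values) < 3:
        return 0.0
    x, y = values[:-1], values[1:]
    mx, my = sum(x) / len(x), sum(y) / len(y)
    num = sum((a - mx) * (b - my) for a, b in zip(x, y))
    dx = sum((a - mx) ** 2 for a in x)
    dy = sum((b - my) ** 2 for b in y)
    return num / math.sqrt(dx * dy) if dx > 0 and dy > 0 else 0.0

random.seed(0)
smooth = [math.sin(t / 20.0) + random.gauss(0, 0.05) for t in range(2000)]
shuffled = random.sample(smooth, len(smooth))  # same marginal, dynamics destroyed

print(round(lag1_corr(smooth), 3))    # close to 1.0
print(round(lag1_corr(shuffled), 3))  # close to 0.0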
@@ -5,8 +5,9 @@ import argparse
 import csv
+import gzip
 import json
 import math
 from pathlib import Path
-from typing import Dict, Tuple
+from typing import Dict, Tuple, List, Optional


 def load_json(path: str) -> Dict:
@@ -28,6 +29,8 @@ def parse_args():
     parser.add_argument("--stats", default=str(base_dir / "results" / "cont_stats.json"))
     parser.add_argument("--vocab", default=str(base_dir / "results" / "disc_vocab.json"))
     parser.add_argument("--out", default=str(base_dir / "results" / "eval.json"))
+    parser.add_argument("--reference", default="", help="Optional reference CSV (train) for richer metrics")
+    parser.add_argument("--max-rows", type=int, default=20000, help="Max rows to load for reference metrics")
     return parser.parse_args()


@@ -55,6 +58,62 @@ def finalize_stats(stats):
     return out


+def js_divergence(p, q, eps: float = 1e-12) -> float:
+    p = [max(x, eps) for x in p]
+    q = [max(x, eps) for x in q]
+    m = [(pi + qi) / 2.0 for pi, qi in zip(p, q)]
+    def kl(a, b):
+        return sum(ai * math.log(ai / bi, 2) for ai, bi in zip(a, b))
+    return 0.5 * kl(p, m) + 0.5 * kl(q, m)
+
+
+def ks_statistic(x: List[float], y: List[float]) -> float:
+    if not x or not y:
+        return 0.0
+    x_sorted = sorted(x)
+    y_sorted = sorted(y)
+    n = len(x_sorted)
+    m = len(y_sorted)
+    i = j = 0
+    cdf_x = cdf_y = 0.0
+    d = 0.0
+    while i < n and j < m:
+        if x_sorted[i] <= y_sorted[j]:
+            i += 1
+            cdf_x = i / n
+        else:
+            j += 1
+            cdf_y = j / m
+        d = max(d, abs(cdf_x - cdf_y))
+    return d
+
+
+def lag1_corr(values: List[float]) -> float:
+    if len(values) < 3:
+        return 0.0
+    x = values[:-1]
+    y = values[1:]
+    mean_x = sum(x) / len(x)
+    mean_y = sum(y) / len(y)
+    num = sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y))
+    den_x = sum((xi - mean_x) ** 2 for xi in x)
+    den_y = sum((yi - mean_y) ** 2 for yi in y)
+    if den_x <= 0 or den_y <= 0:
+        return 0.0
+    return num / math.sqrt(den_x * den_y)
+
+
+def resolve_reference_path(path: str) -> Optional[str]:
+    if not path:
+        return None
+    if any(ch in path for ch in ["*", "?", "["]):
+        base = Path(path).parent
+        pat = Path(path).name
+        matches = sorted(base.glob(pat))
+        return str(matches[0]) if matches else None
+    return str(path)
+
+
 def main():
     args = parse_args()
     base_dir = Path(__file__).resolve().parent
@@ -63,13 +122,18 @@ def main():
     args.stats = str((base_dir / args.stats).resolve()) if not Path(args.stats).is_absolute() else args.stats
     args.vocab = str((base_dir / args.vocab).resolve()) if not Path(args.vocab).is_absolute() else args.vocab
     args.out = str((base_dir / args.out).resolve()) if not Path(args.out).is_absolute() else args.out
+    if args.reference and not Path(args.reference).is_absolute():
+        args.reference = str((base_dir / args.reference).resolve())
+    ref_path = resolve_reference_path(args.reference)
     split = load_json(args.split)
     time_col = split.get("time_column", "time")
     cont_cols = [c for c in split["continuous"] if c != time_col]
     disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]

-    stats_ref = load_json(args.stats)["mean"]
-    std_ref = load_json(args.stats)["std"]
+    stats_json = load_json(args.stats)
+    stats_ref = stats_json.get("raw_mean", stats_json.get("mean"))
+    std_ref = stats_json.get("raw_std", stats_json.get("std"))
+    transforms = stats_json.get("transform", {})
     vocab = load_json(args.vocab)["vocab"]
     vocab_sets = {c: set(vocab[c].keys()) for c in disc_cols}

@@ -89,6 +153,8 @@ def main():
             except Exception:
                 v = 0.0
             update_stats(cont_stats, c, v)
+        if ref_path:
+            pass  # placeholder: reference-based metrics are computed after this loop
         for c in disc_cols:
             if row[c] not in vocab_sets[c]:
                 disc_invalid[c] += 1
@@ -112,6 +178,81 @@
         "discrete_invalid_counts": disc_invalid,
     }

+    # Optional richer metrics using reference data
+    if ref_path:
+        ref_cont = {c: [] for c in cont_cols}
+        ref_disc = {c: {} for c in disc_cols}
+        gen_cont = {c: [] for c in cont_cols}
+        gen_disc = {c: {} for c in disc_cols}
+
+        with open_csv(args.generated) as f:
+            reader = csv.DictReader(f)
+            for row in reader:
+                if time_col in row:
+                    row.pop(time_col, None)
+                for c in cont_cols:
+                    try:
+                        gen_cont[c].append(float(row[c]))
+                    except Exception:
+                        gen_cont[c].append(0.0)
+                for c in disc_cols:
+                    tok = row[c]
+                    gen_disc[c][tok] = gen_disc[c].get(tok, 0) + 1
+
+        with open_csv(ref_path) as f:
+            reader = csv.DictReader(f)
+            for i, row in enumerate(reader):
+                if time_col in row:
+                    row.pop(time_col, None)
+                for c in cont_cols:
+                    try:
+                        ref_cont[c].append(float(row[c]))
+                    except Exception:
+                        ref_cont[c].append(0.0)
+                for c in disc_cols:
+                    tok = row[c]
+                    ref_disc[c][tok] = ref_disc[c].get(tok, 0) + 1
+                if args.max_rows and i + 1 >= args.max_rows:
+                    break
+
+        # Continuous metrics: KS + quantiles + lag1 correlation
+        cont_ks = {}
+        cont_quant = {}
+        cont_lag1 = {}
+        for c in cont_cols:
+            cont_ks[c] = ks_statistic(gen_cont[c], ref_cont[c])
+            ref_sorted = sorted(ref_cont[c])
+            gen_sorted = sorted(gen_cont[c])
+            qs = [0.05, 0.25, 0.5, 0.75, 0.95]
+            def qval(arr, q):
+                if not arr:
+                    return 0.0
+                idx = int(q * (len(arr) - 1))
+                return arr[idx]
+            cont_quant[c] = {
+                "q05_diff": abs(qval(gen_sorted, 0.05) - qval(ref_sorted, 0.05)),
+                "q25_diff": abs(qval(gen_sorted, 0.25) - qval(ref_sorted, 0.25)),
+                "q50_diff": abs(qval(gen_sorted, 0.5) - qval(ref_sorted, 0.5)),
+                "q75_diff": abs(qval(gen_sorted, 0.75) - qval(ref_sorted, 0.75)),
+                "q95_diff": abs(qval(gen_sorted, 0.95) - qval(ref_sorted, 0.95)),
+            }
+            cont_lag1[c] = abs(lag1_corr(gen_cont[c]) - lag1_corr(ref_cont[c]))
+
+        # Discrete metrics: JSD over vocab
+        disc_jsd = {}
+        for c in disc_cols:
+            vocab_vals = list(vocab_sets[c])
+            gen_total = sum(gen_disc[c].values()) or 1
+            ref_total = sum(ref_disc[c].values()) or 1
+            p = [gen_disc[c].get(v, 0) / gen_total for v in vocab_vals]
+            q = [ref_disc[c].get(v, 0) / ref_total for v in vocab_vals]
+            disc_jsd[c] = js_divergence(p, q)
+
+        report["continuous_ks"] = cont_ks
+        report["continuous_quantile_diff"] = cont_quant
+        report["continuous_lag1_diff"] = cont_lag1
+        report["discrete_jsd"] = disc_jsd
+
     with open(args.out, "w", encoding="utf-8") as f:
         json.dump(report, f, indent=2)

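As a quick sanity check of the new helpers (toy values are ours, assuming `js_divergence` and `ks_statistic` are importable from the script): identical distributions should score 0 under base-2 JSD, disjoint ones ~1.0, and a fully shifted sample should drive the KS statistic to 1.

# Toy check (our own numbers, not repo data).
p = [0.5, 0.5, 0.0]
q = [0.0, 0.0, 1.0]
print(round(js_divergence(p, p), 6))  # -> 0.0
print(round(js_divergence(p, q), 6))  # -> ~1.0

x = [0.1 * k for k in range(100)]          # 0.0 .. 9.9
y = [0.1 * k + 20.0 for k in range(100)]   # shifted: no overlap with x
print(ks_statistic(x, y))                  # -> 1.0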