Continuous features fall short on temporal correlation

2026-01-23 15:06:52 +08:00
parent 0d17be9a1c
commit ff12324560
12 changed files with 1212 additions and 68 deletions


@@ -5,8 +5,9 @@ import argparse
import csv
import gzip
import json
import math
from pathlib import Path
from typing import Dict, Tuple, List, Optional
def load_json(path: str) -> Dict:
@@ -28,6 +29,8 @@ def parse_args():
parser.add_argument("--stats", default=str(base_dir / "results" / "cont_stats.json"))
parser.add_argument("--vocab", default=str(base_dir / "results" / "disc_vocab.json"))
parser.add_argument("--out", default=str(base_dir / "results" / "eval.json"))
parser.add_argument("--reference", default="", help="Optional reference CSV (train) for richer metrics")
parser.add_argument("--max-rows", type=int, default=20000, help="Max rows to load for reference metrics")
return parser.parse_args()
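For orientation, the two new flags can be exercised in isolation like this (a minimal sketch; the simulated argv values are illustrative, not from the repo):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--reference", default="", help="Optional reference CSV (train) for richer metrics")
parser.add_argument("--max-rows", type=int, default=20000, help="Max rows to load for reference metrics")
args = parser.parse_args(["--reference", "results/train.csv.gz", "--max-rows", "5000"])
print(args.reference, args.max_rows)  # results/train.csv.gz 5000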
@@ -55,6 +58,62 @@ def finalize_stats(stats):
    return out
def js_divergence(p, q, eps: float = 1e-12) -> float:
    # Jensen-Shannon divergence in bits (log base 2); eps guards against log(0).
    p = [max(x, eps) for x in p]
    q = [max(x, eps) for x in q]
    m = [(pi + qi) / 2.0 for pi, qi in zip(p, q)]

    def kl(a, b):
        return sum(ai * math.log(ai / bi, 2) for ai, bi in zip(a, b))

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)
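Two boundary cases pin down the scale of js_divergence as written (a standalone sanity check; with log base 2 the value lives in [0, 1]):

import math

def js_divergence(p, q, eps: float = 1e-12) -> float:
    p = [max(x, eps) for x in p]
    q = [max(x, eps) for x in q]
    m = [(pi + qi) / 2.0 for pi, qi in zip(p, q)]
    def kl(a, b):
        return sum(ai * math.log(ai / bi, 2) for ai, bi in zip(a, b))
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

print(js_divergence([0.5, 0.5], [0.5, 0.5]))  # 0.0: identical distributions
print(js_divergence([1.0, 0.0], [0.0, 1.0]))  # ~1.0: disjoint support, maximal in bits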
def ks_statistic(x: List[float], y: List[float]) -> float:
    # Two-sample Kolmogorov-Smirnov statistic via a merge-style scan of the
    # sorted samples: the maximum gap between the two empirical CDFs.
    if not x or not y:
        return 0.0
    x_sorted = sorted(x)
    y_sorted = sorted(y)
    n = len(x_sorted)
    m = len(y_sorted)
    i = j = 0
    cdf_x = cdf_y = 0.0
    d = 0.0
    while i < n and j < m:
        if x_sorted[i] < y_sorted[j]:
            i += 1
            cdf_x = i / n
        elif x_sorted[i] > y_sorted[j]:
            j += 1
            cdf_y = j / m
        else:
            # Tie: step both CDFs past the shared value before comparing,
            # otherwise equal samples inflate the statistic.
            v = x_sorted[i]
            while i < n and x_sorted[i] == v:
                i += 1
            while j < m and y_sorted[j] == v:
                j += 1
            cdf_x = i / n
            cdf_y = j / m
        d = max(d, abs(cdf_x - cdf_y))
    return d
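The merge scan above is O(n log n), dominated by the sorts; an easy way to test it is against a brute-force version that evaluates both empirical CDFs at every observed value (a throwaway sketch, assuming ks_statistic from above is in scope):

def ks_brute_force(x, y):
    # max over all sample points t of |F_x(t) - F_y(t)|
    pts = sorted(set(x) | set(y))
    n, m = len(x), len(y)
    return max(abs(sum(v <= t for v in x) / n - sum(v <= t for v in y) / m) for t in pts)

a, b = [1.0, 2.0, 3.0], [1.0, 2.0, 10.0]
assert abs(ks_statistic(a, b) - 1 / 3) < 1e-12          # max gap is at t=3
assert abs(ks_brute_force(a, b) - ks_statistic(a, b)) < 1e-12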
def lag1_corr(values: List[float]) -> float:
    # Lag-1 autocorrelation: Pearson correlation of the series with itself
    # shifted by one step; returns 0.0 for too-short or zero-variance series.
    if len(values) < 3:
        return 0.0
    x = values[:-1]
    y = values[1:]
    mean_x = sum(x) / len(x)
    mean_y = sum(y) / len(y)
    num = sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y))
    den_x = sum((xi - mean_x) ** 2 for xi in x)
    den_y = sum((yi - mean_y) ** 2 for yi in y)
    if den_x <= 0 or den_y <= 0:
        return 0.0
    return num / math.sqrt(den_x * den_y)
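Since lag1_corr is just the Pearson correlation between the series and its one-step shift, a few hand-checkable inputs make good unit tests (using lag1_corr as defined above):

print(lag1_corr([1.0, 2.0, 3.0, 4.0, 5.0]))    # 1.0: perfectly persistent ramp
print(lag1_corr([1.0, -1.0, 1.0, -1.0, 1.0]))  # -1.0: perfectly alternating
print(lag1_corr([7.0, 7.0, 7.0]))              # 0.0: constant series, zero-variance guard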
def resolve_reference_path(path: str) -> Optional[str]:
    # Accept either a plain path or a glob pattern; for a pattern, return
    # the first match in sorted order, or None if nothing matches.
    if not path:
        return None
    if any(ch in path for ch in ["*", "?", "["]):
        base = Path(path).parent
        pat = Path(path).name
        matches = sorted(base.glob(pat))
        return str(matches[0]) if matches else None
    return str(path)
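Usage-wise, resolve_reference_path only globs when the string contains *, ?, or [, and plain paths pass through without an existence check (file names below are hypothetical):

# Suppose results/ contains train_a.csv.gz and train_b.csv.gz:
resolve_reference_path("results/train_*.csv.gz")  # -> "results/train_a.csv.gz" (first sorted match)
resolve_reference_path("results/none_*.csv.gz")   # -> None (pattern with no matches)
resolve_reference_path("results/train.csv")       # -> "results/train.csv" (returned unverified)
resolve_reference_path("")                        # -> None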
def main():
    args = parse_args()
    base_dir = Path(__file__).resolve().parent
@@ -63,13 +122,18 @@ def main():
    args.stats = str((base_dir / args.stats).resolve()) if not Path(args.stats).is_absolute() else args.stats
    args.vocab = str((base_dir / args.vocab).resolve()) if not Path(args.vocab).is_absolute() else args.vocab
    args.out = str((base_dir / args.out).resolve()) if not Path(args.out).is_absolute() else args.out
    if args.reference and not Path(args.reference).is_absolute():
        args.reference = str((base_dir / args.reference).resolve())
    ref_path = resolve_reference_path(args.reference)
    split = load_json(args.split)
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
    stats_json = load_json(args.stats)
    # Prefer raw-space moments when present; fall back to the legacy fields.
    stats_ref = stats_json.get("raw_mean", stats_json.get("mean"))
    std_ref = stats_json.get("raw_std", stats_json.get("std"))
    transforms = stats_json.get("transform", {})
    vocab = load_json(args.vocab)["vocab"]
    vocab_sets = {c: set(vocab[c].keys()) for c in disc_cols}
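The .get fallback chain implies cont_stats.json may carry either raw-space moments (raw_mean/raw_std) or only the legacy mean/std, plus an optional transform map. A hypothetical minimal shape consistent with these lookups (the column name and transform tag are invented for illustration):

stats_json = {
    "raw_mean": {"flow_rate": 3.2},       # per-column mean in raw units (hypothetical column)
    "raw_std": {"flow_rate": 1.1},
    "mean": {"flow_rate": 0.0},           # legacy fields, used only as fallback
    "std": {"flow_rate": 1.0},
    "transform": {"flow_rate": "log1p"},  # optional; the tag name is an assumption
}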
@@ -89,6 +153,8 @@ def main():
                except Exception:
                    v = 0.0
                update_stats(cont_stats, c, v)
            if ref_path:
                # Reference-based metrics are computed in a second pass below;
                # nothing extra needs collecting per row here.
                pass
            for c in disc_cols:
                if row[c] not in vocab_sets[c]:
                    disc_invalid[c] += 1
@@ -112,6 +178,81 @@ def main():
"discrete_invalid_counts": disc_invalid,
}
    # Optional richer metrics using reference data
    if ref_path:
        ref_cont = {c: [] for c in cont_cols}
        ref_disc = {c: {} for c in disc_cols}
        gen_cont = {c: [] for c in cont_cols}
        gen_disc = {c: {} for c in disc_cols}
        with open_csv(args.generated) as f:
            reader = csv.DictReader(f)
            for row in reader:
                row.pop(time_col, None)
                for c in cont_cols:
                    try:
                        gen_cont[c].append(float(row[c]))
                    except Exception:
                        gen_cont[c].append(0.0)
                for c in disc_cols:
                    tok = row[c]
                    gen_disc[c][tok] = gen_disc[c].get(tok, 0) + 1
        with open_csv(ref_path) as f:
            reader = csv.DictReader(f)
            for i, row in enumerate(reader):
                row.pop(time_col, None)
                for c in cont_cols:
                    try:
                        ref_cont[c].append(float(row[c]))
                    except Exception:
                        ref_cont[c].append(0.0)
                for c in disc_cols:
                    tok = row[c]
                    ref_disc[c][tok] = ref_disc[c].get(tok, 0) + 1
                if args.max_rows and i + 1 >= args.max_rows:
                    break
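open_csv itself is outside this hunk; given the gzip import at the top of the file, it is presumably a gzip-aware opener along these lines (a sketch of the assumed helper, not the author's definition):

def open_csv(path):
    # Transparently open plain or gzip-compressed CSVs in text mode.
    if str(path).endswith(".gz"):
        return gzip.open(path, "rt", encoding="utf-8", newline="")
    return open(path, "r", encoding="utf-8", newline="")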
        # Continuous metrics: KS + quantile diffs + lag-1 correlation gap
        def qval(arr, q):
            # Floor-index empirical quantile (no interpolation); arr is sorted.
            if not arr:
                return 0.0
            idx = int(q * (len(arr) - 1))
            return arr[idx]

        qs = [0.05, 0.25, 0.5, 0.75, 0.95]
        cont_ks = {}
        cont_quant = {}
        cont_lag1 = {}
        for c in cont_cols:
            cont_ks[c] = ks_statistic(gen_cont[c], ref_cont[c])
            ref_sorted = sorted(ref_cont[c])
            gen_sorted = sorted(gen_cont[c])
            cont_quant[c] = {
                f"q{round(q * 100):02d}_diff": abs(qval(gen_sorted, q) - qval(ref_sorted, q))
                for q in qs
            }
            cont_lag1[c] = abs(lag1_corr(gen_cont[c]) - lag1_corr(ref_cont[c]))
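One caveat worth remembering: qval uses a floor index with no interpolation, so tail quantiles on short arrays are coarse. A small illustrative check:

arr = [10.0, 20.0, 30.0, 40.0, 50.0]
idx = int(0.05 * (len(arr) - 1))  # int(0.2) -> 0
print(arr[idx])                   # 10.0; interpolating (as numpy.quantile does) would give 12.0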
        # Discrete metrics: Jensen-Shannon divergence over the training vocab.
        # Out-of-vocab tokens count toward the totals but get no mass in p/q,
        # so p and q need not sum to exactly 1 when such tokens are present.
        disc_jsd = {}
        for c in disc_cols:
            vocab_vals = list(vocab_sets[c])
            gen_total = sum(gen_disc[c].values()) or 1
            ref_total = sum(ref_disc[c].values()) or 1
            p = [gen_disc[c].get(v, 0) / gen_total for v in vocab_vals]
            q = [ref_disc[c].get(v, 0) / ref_total for v in vocab_vals]
            disc_jsd[c] = js_divergence(p, q)
report["continuous_ks"] = cont_ks
report["continuous_quantile_diff"] = cont_quant
report["continuous_lag1_diff"] = cont_lag1
report["discrete_jsd"] = disc_jsd
with open(args.out, "w", encoding="utf-8") as f:
json.dump(report, f, indent=2)
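Downstream, the enriched report is easy to triage; a minimal sketch that ranks columns by KS statistic (key names match the fields written above; the eval.json path is just the script's default):

import json

with open("results/eval.json", encoding="utf-8") as f:
    report = json.load(f)

# Rank continuous columns by distributional mismatch (KS in [0, 1]).
for col, ks in sorted(report.get("continuous_ks", {}).items(), key=lambda kv: -kv[1])[:5]:
    lag = report["continuous_lag1_diff"][col]
    print(f"{col}: KS={ks:.3f} lag1_diff={lag:.3f}")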