This commit is contained in:
2026-01-27 18:19:07 +08:00
parent 513dce2f4c
commit dc78f336bc
8 changed files with 335 additions and 0 deletions

View File

@@ -10,3 +10,6 @@ Conventions:
- Append new entries instead of overwriting old ones.
- Record exact config file and key overrides when possible.
- Keep metrics in the order: avg_ks / avg_jsd / avg_lag1_diff.
Tools:
- `example/diagnose_ks.py` for per-feature KS + CDF plots.

View File

@@ -33,3 +33,9 @@
- `example/sample.py`
- `example/export_samples.py`
- `example/config.json`
## 2026-01-26 — Per-feature KS diagnostics
- **Decision**: Add a per-feature KS/CDF diagnostic script to pinpoint KS failures (tails, boundary pile-up, shifts).
- **Why**: Avoid blind reweighting and find the specific features causing KS to stay high.
- **Files**:
- `example/diagnose_ks.py`

227
example/diagnose_ks.py Normal file
View File

@@ -0,0 +1,227 @@
#!/usr/bin/env python3
"""Per-feature KS diagnostics and CDF visualization (no third-party deps)."""
import argparse
import csv
import gzip
import json
import math
from pathlib import Path
from glob import glob
def parse_args():
parser = argparse.ArgumentParser(description="Per-feature KS diagnostics.")
base_dir = Path(__file__).resolve().parent
parser.add_argument("--generated", default=str(base_dir / "results" / "generated.csv"))
parser.add_argument("--reference", default=str(base_dir / "config.json"))
parser.add_argument("--out-dir", default=str(base_dir / "results"))
parser.add_argument("--max-rows", type=int, default=200000, help="<=0 for full scan")
parser.add_argument("--stride", type=int, default=1, help="row stride sampling")
parser.add_argument("--top-k", type=int, default=8)
return parser.parse_args()
def load_split(base_dir: Path):
    """Return the continuous feature columns from ``feature_split.json``.

    The time column (key ``time_column``, default ``"time"``) is excluded
    from the returned list.
    """
    split_path = base_dir / "feature_split.json"
    with open(split_path, "r", encoding="utf-8") as handle:
        spec = json.load(handle)
    time_name = spec.get("time_column", "time")
    return [col for col in spec["continuous"] if col != time_name]
def resolve_reference_glob(base_dir: Path, ref_arg: str):
    """Resolve *ref_arg* into a glob string for the reference CSV files.

    A ``.json`` argument is read as a config file whose ``data_glob`` (or
    ``data_path``) entry is resolved relative to the config's own folder.
    Any other argument is returned unchanged as a path string.

    NOTE(review): ``base_dir`` is currently unused; kept to preserve the
    call signature.

    Raises:
        SystemExit: if the config defines neither data_glob nor data_path.
    """
    candidate = Path(ref_arg)
    if candidate.suffix != ".json":
        return str(candidate)
    cfg = json.loads(candidate.read_text(encoding="utf-8"))
    pattern = cfg.get("data_glob") or cfg.get("data_path") or ""
    if not pattern:
        raise SystemExit("reference config has no data_glob/data_path")
    return str((candidate.parent / pattern).resolve())
def read_csv_values(path: Path, cols, max_rows=200000, stride=1, gz=True):
    """Read selected numeric columns from a (possibly gzipped) CSV file.

    Args:
        path: CSV file to read.
        cols: Column names to collect.
        max_rows: Stop after this many kept rows; <=0 scans the whole file.
        stride: Keep every *stride*-th row (1 keeps all rows).
        gz: Whether the file is gzip-compressed.

    Returns:
        ``(values, row_count)`` — dict mapping column -> list of finite
        floats, and the number of rows kept. Non-numeric or non-finite
        cells are skipped per-column, so lists may differ in length.
    """
    values = {c: [] for c in cols}
    row_count = 0
    # Choose the opener once; both yield a text-mode file object.
    if gz:
        fh = gzip.open(path, "rt", newline="")
    else:
        fh = open(path, "r", newline="", encoding="utf-8")
    with fh:  # context manager replaces the original try/finally + close
        reader = csv.DictReader(fh)
        for i, row in enumerate(reader):
            if stride > 1 and i % stride != 0:
                continue
            for c in cols:
                try:
                    fv = float(row.get(c, ""))
                except (TypeError, ValueError):
                    # Non-numeric cell: skip this column, keep the row.
                    continue
                if math.isfinite(fv):
                    values[c].append(fv)
            row_count += 1
            if max_rows > 0 and row_count >= max_rows:
                break
    return values, row_count
def ks_statistic(a, b):
    """Two-sample Kolmogorov-Smirnov statistic: sup_x |F_a(x) - F_b(x)|.

    Returns 1.0 when either sample is empty (maximal divergence by
    convention). Handles ties correctly: all occurrences of the current
    value are consumed from BOTH samples before the ECDF gap is compared.

    The previous implementation advanced only one pointer per step and
    evaluated the gap mid-tie, which inflated D — e.g. two identical
    samples ``[1]`` vs ``[1]`` reported D=1.0 instead of 0.0.
    """
    if not a or not b:
        return 1.0
    a = sorted(a)
    b = sorted(b)
    na = len(a)
    nb = len(b)
    i = j = 0
    d = 0.0
    while i < na and j < nb:
        # Smallest value not yet consumed from either sample.
        x = a[i] if a[i] <= b[j] else b[j]
        while i < na and a[i] == x:
            i += 1
        while j < nb and b[j] == x:
            j += 1
        d = max(d, abs(i / na - j / nb))
    return d
def ecdf_points(vals):
    """Return ``(xs, ys)`` step points of the empirical CDF of *vals*.

    ``xs`` holds the distinct values in ascending order; ``ys[k]`` is the
    fraction of samples <= ``xs[k]``. Empty input yields two empty lists.
    """
    ordered = sorted(vals)
    total = len(ordered)
    if total == 0:
        return [], []
    xs, ys = [], []
    prev = None
    for rank, value in enumerate(ordered, start=1):
        frac = rank / total
        if xs and value == prev:
            # Duplicate value: raise the existing step instead of adding one.
            ys[-1] = frac
        else:
            xs.append(value)
            ys.append(frac)
            prev = value
    return xs, ys
def render_cdf_svg(out_path: Path, feature, real_vals, gen_vals, bounds=None):
    """Write an SVG comparing real vs generated ECDF curves for one feature.

    Draws both step curves on a shared axis (real in blue, generated in
    red); when *bounds* is given, dashed vertical guides mark (lo, hi).
    Writes nothing when either sample is empty.
    """
    width, height = 900, 420
    pad = 50
    panel_w = width - pad * 2
    panel_h = height - pad * 2
    if not real_vals or not gen_vals:
        return
    lo_v = min(min(real_vals), min(gen_vals))
    hi_v = max(max(real_vals), max(gen_vals))
    if hi_v == lo_v:
        hi_v += 1.0  # avoid a zero-width x range
    real_x, real_y = ecdf_points(real_vals)
    gen_x, gen_y = ecdf_points(gen_vals)

    def to_px_x(v):
        # Linear map of a data value onto the horizontal panel.
        return pad + int((v - lo_v) * panel_w / (hi_v - lo_v))

    def to_px_y(v):
        # CDF fraction (0..1) mapped to pixels, y axis pointing up.
        return pad + panel_h - int(v * panel_h)

    def polyline(xs, ys, color):
        pts = " ".join(f"{to_px_x(x)},{to_px_y(y)}" for x, y in zip(xs, ys))
        return f'<polyline fill="none" stroke="{color}" stroke-width="2" points="{pts}"/>'

    parts = [
        f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}">',
        '<style>text{font-family:Arial,sans-serif;font-size:12px}</style>',
        f'<text x="{pad}" y="{pad-20}">CDF 비교: {feature}</text>',
        f'<line x1="{pad}" y1="{pad}" x2="{pad}" y2="{pad+panel_h}" stroke="#333"/>',
        f'<line x1="{pad}" y1="{pad+panel_h}" x2="{pad+panel_w}" y2="{pad+panel_h}" stroke="#333"/>',
        polyline(real_x, real_y, "#1f77b4"),
        polyline(gen_x, gen_y, "#d62728"),
        f'<text x="{pad+panel_w-120}" y="{pad+15}" fill="#1f77b4">real</text>',
        f'<text x="{pad+panel_w-120}" y="{pad+30}" fill="#d62728">generated</text>',
    ]
    if bounds is not None:
        lo, hi = bounds
        for guide in (lo, hi):
            parts.append(
                f'<line x1="{to_px_x(guide)}" y1="{pad}" x2="{to_px_x(guide)}" y2="{pad+panel_h}" stroke="#999" stroke-dasharray="4 3"/>'
            )
    parts.append("</svg>")
    out_path.write_text("\n".join(parts), encoding="utf-8")
def main():
    """Run per-feature KS diagnostics; emit CSV, SVG and JSON reports."""
    args = parse_args()
    base_dir = Path(__file__).resolve().parent
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    cont_cols = load_split(base_dir)
    ref_glob = resolve_reference_glob(base_dir, args.reference)
    ref_files = sorted(glob(ref_glob))
    if not ref_files:
        raise SystemExit(f"no reference files matched: {ref_glob}")

    # Generated sample: single plain-text CSV.
    gen_vals, gen_rows = read_csv_values(
        Path(args.generated), cont_cols, args.max_rows, args.stride, gz=False)

    # Reference sample: aggregate values across every matched gzipped file.
    real_vals = {c: [] for c in cont_cols}
    total_rows = 0  # NOTE(review): accumulated but not reported downstream
    for ref_file in ref_files:
        vals, rows = read_csv_values(
            Path(ref_file), cont_cols, args.max_rows, args.stride, gz=True)
        total_rows += rows
        for c in cont_cols:
            real_vals[c].extend(vals[c])

    # Per-feature KS plus boundary pile-up fractions (relative to real
    # min/max) — high frac_lo/frac_hi flags generated values stuck at the
    # edges of the observed range.
    ks_rows = []
    for c in cont_cols:
        ks = ks_statistic(real_vals[c], gen_vals[c])
        if real_vals[c]:
            lo, hi = min(real_vals[c]), max(real_vals[c])
        else:
            lo = hi = 0.0
        tol = (hi - lo) * 1e-4 if hi > lo else 1e-6
        gen = gen_vals[c]
        if gen:
            frac_lo = sum(1 for v in gen if abs(v - lo) <= tol) / len(gen)
            frac_hi = sum(1 for v in gen if abs(v - hi) <= tol) / len(gen)
        else:
            frac_lo = frac_hi = 0.0
        ks_rows.append((c, ks, frac_lo, frac_hi, len(real_vals[c]), len(gen_vals[c]), lo, hi))
    ks_rows.sort(key=lambda r: r[1], reverse=True)

    # Full per-feature table, worst KS first.
    out_csv = out_dir / "ks_per_feature.csv"
    with out_csv.open("w", newline="") as fh:
        writer = csv.writer(fh)
        writer.writerow(["feature", "ks", "gen_frac_at_min", "gen_frac_at_max",
                         "real_n", "gen_n", "real_min", "real_max"])
        writer.writerows(ks_rows)

    # CDF plots for only the top-k worst offenders.
    for c, _ks, _fl, _fh, _rn, _gn, lo, hi in ks_rows[: args.top_k]:
        render_cdf_svg(out_dir / f"cdf_{c}.svg", c, real_vals[c], gen_vals[c], bounds=(lo, hi))

    summary = {
        "generated_rows": gen_rows,
        "reference_rows_per_file": args.max_rows if args.max_rows > 0 else "full",
        "stride": args.stride,
        "top_k_features": [r[0] for r in ks_rows[: args.top_k]],
    }
    (out_dir / "ks_summary.json").write_text(json.dumps(summary, indent=2), encoding="utf-8")
    print(f"wrote {out_csv}")
    print(f"wrote CDF svgs for top {args.top_k} features")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,43 @@
<svg xmlns="http://www.w3.org/2000/svg" width="1200" height="700">
<style>text{font-family:Arial,sans-serif;font-size:12px}</style>
<text x="60" y="40">Per-file mean (sampled every 50 rows)</text>
<text x="60" y="360">Per-file std (sampled every 50 rows)</text>
<line x1="60" y1="60" x2="60" y2="320" stroke="#333"/>
<line x1="60" y1="320" x2="1140" y2="320" stroke="#333"/>
<text x="50" y="70" text-anchor="end">54106.136</text>
<text x="50" y="320" text-anchor="end">1.370</text>
<text x="60" y="335" text-anchor="middle" transform="rotate(45 60 335)">train1.csv.gz</text>
<text x="600" y="335" text-anchor="middle" transform="rotate(45 600 335)">train2.csv.gz</text>
<text x="1140" y="335" text-anchor="middle" transform="rotate(45 1140 335)">train3.csv.gz</text>
<polyline fill="none" stroke="#1f77b4" stroke-width="2" points="60,320 600,320 1140,320"/>
<text x="1150" y="75" fill="#1f77b4">P1_FT01</text>
<polyline fill="none" stroke="#ff7f0e" stroke-width="2" points="60,319 600,319 1140,319"/>
<text x="1150" y="90" fill="#ff7f0e">P1_LIT01</text>
<polyline fill="none" stroke="#2ca02c" stroke-width="2" points="60,320 600,320 1140,320"/>
<text x="1150" y="105" fill="#2ca02c">P1_PIT01</text>
<polyline fill="none" stroke="#d62728" stroke-width="2" points="60,60 600,61 1140,61"/>
<text x="1150" y="120" fill="#d62728">P2_CO_rpm</text>
<polyline fill="none" stroke="#9467bd" stroke-width="2" points="60,255 600,256 1140,257"/>
<text x="1150" y="135" fill="#9467bd">P3_LIT01</text>
<polyline fill="none" stroke="#8c564b" stroke-width="2" points="60,272 600,272 1140,272"/>
<text x="1150" y="150" fill="#8c564b">P4_ST_PT01</text>
<line x1="60" y1="380" x2="60" y2="640" stroke="#333"/>
<line x1="60" y1="640" x2="1140" y2="640" stroke="#333"/>
<text x="50" y="390" text-anchor="end">4403.032</text>
<text x="50" y="640" text-anchor="end">0.053</text>
<text x="60" y="655" text-anchor="middle" transform="rotate(45 60 655)">train1.csv.gz</text>
<text x="600" y="655" text-anchor="middle" transform="rotate(45 600 655)">train2.csv.gz</text>
<text x="1140" y="655" text-anchor="middle" transform="rotate(45 1140 655)">train3.csv.gz</text>
<polyline fill="none" stroke="#1f77b4" stroke-width="2" points="60,639 600,639 1140,639"/>
<text x="1150" y="395" fill="#1f77b4">P1_FT01</text>
<polyline fill="none" stroke="#ff7f0e" stroke-width="2" points="60,640 600,640 1140,639"/>
<text x="1150" y="410" fill="#ff7f0e">P1_LIT01</text>
<polyline fill="none" stroke="#2ca02c" stroke-width="2" points="60,640 600,640 1140,640"/>
<text x="1150" y="425" fill="#2ca02c">P1_PIT01</text>
<polyline fill="none" stroke="#d62728" stroke-width="2" points="60,639 600,639 1140,639"/>
<text x="1150" y="440" fill="#d62728">P2_CO_rpm</text>
<polyline fill="none" stroke="#9467bd" stroke-width="2" points="60,389 600,386 1140,380"/>
<text x="1150" y="455" fill="#9467bd">P3_LIT01</text>
<polyline fill="none" stroke="#8c564b" stroke-width="2" points="60,640 600,639 1140,640"/>
<text x="1150" y="470" fill="#8c564b">P4_ST_PT01</text>
</svg>

After

Width:  |  Height:  |  Size: 2.9 KiB

View File

@@ -0,0 +1,43 @@
<svg xmlns="http://www.w3.org/2000/svg" width="1200" height="700">
<style>text{font-family:Arial,sans-serif;font-size:12px}</style>
<text x="60" y="40">Per-file mean (full data)</text>
<text x="60" y="360">Per-file std (full data)</text>
<line x1="60" y1="60" x2="60" y2="320" stroke="#333"/>
<line x1="60" y1="320" x2="1140" y2="320" stroke="#333"/>
<text x="50" y="70" text-anchor="end">54105.286</text>
<text x="50" y="320" text-anchor="end">1.370</text>
<text x="60" y="335" text-anchor="middle" transform="rotate(45 60 335)">train1.csv.gz</text>
<text x="600" y="335" text-anchor="middle" transform="rotate(45 600 335)">train2.csv.gz</text>
<text x="1140" y="335" text-anchor="middle" transform="rotate(45 1140 335)">train3.csv.gz</text>
<polyline fill="none" stroke="#1f77b4" stroke-width="2" points="60,320 600,320 1140,320"/>
<text x="1150" y="75" fill="#1f77b4">P1_FT01</text>
<polyline fill="none" stroke="#ff7f0e" stroke-width="2" points="60,319 600,319 1140,319"/>
<text x="1150" y="90" fill="#ff7f0e">P1_LIT01</text>
<polyline fill="none" stroke="#2ca02c" stroke-width="2" points="60,320 600,320 1140,320"/>
<text x="1150" y="105" fill="#2ca02c">P1_PIT01</text>
<polyline fill="none" stroke="#d62728" stroke-width="2" points="60,60 600,61 1140,61"/>
<text x="1150" y="120" fill="#d62728">P2_CO_rpm</text>
<polyline fill="none" stroke="#9467bd" stroke-width="2" points="60,255 600,256 1140,257"/>
<text x="1150" y="135" fill="#9467bd">P3_LIT01</text>
<polyline fill="none" stroke="#8c564b" stroke-width="2" points="60,272 600,272 1140,272"/>
<text x="1150" y="150" fill="#8c564b">P4_ST_PT01</text>
<line x1="60" y1="380" x2="60" y2="640" stroke="#333"/>
<line x1="60" y1="640" x2="1140" y2="640" stroke="#333"/>
<text x="50" y="390" text-anchor="end">4409.026</text>
<text x="50" y="640" text-anchor="end">0.056</text>
<text x="60" y="655" text-anchor="middle" transform="rotate(45 60 655)">train1.csv.gz</text>
<text x="600" y="655" text-anchor="middle" transform="rotate(45 600 655)">train2.csv.gz</text>
<text x="1140" y="655" text-anchor="middle" transform="rotate(45 1140 655)">train3.csv.gz</text>
<polyline fill="none" stroke="#1f77b4" stroke-width="2" points="60,639 600,639 1140,639"/>
<text x="1150" y="395" fill="#1f77b4">P1_FT01</text>
<polyline fill="none" stroke="#ff7f0e" stroke-width="2" points="60,640 600,640 1140,639"/>
<text x="1150" y="410" fill="#ff7f0e">P1_LIT01</text>
<polyline fill="none" stroke="#2ca02c" stroke-width="2" points="60,640 600,640 1140,640"/>
<text x="1150" y="425" fill="#2ca02c">P1_PIT01</text>
<polyline fill="none" stroke="#d62728" stroke-width="2" points="60,639 600,639 1140,639"/>
<text x="1150" y="440" fill="#d62728">P2_CO_rpm</text>
<polyline fill="none" stroke="#9467bd" stroke-width="2" points="60,389 600,386 1140,380"/>
<text x="1150" y="455" fill="#9467bd">P3_LIT01</text>
<polyline fill="none" stroke="#8c564b" stroke-width="2" points="60,640 600,639 1140,640"/>
<text x="1150" y="470" fill="#8c564b">P4_ST_PT01</text>
</svg>

After

Width:  |  Height:  |  Size: 2.9 KiB

View File

@@ -0,0 +1,4 @@
file,sample_rows,mean_P1_FT01,mean_P1_LIT01,mean_P1_PIT01,mean_P2_CO_rpm,mean_P3_LIT01,mean_P4_ST_PT01,std_P1_FT01,std_P1_LIT01,std_P1_PIT01,std_P2_CO_rpm,std_P3_LIT01,std_P4_ST_PT01
train1.csv.gz,4321,197.3625650034721,402.4983721615365,1.3701480791483485,54106.13631103911,13582.384170330943,10048.28187919463,25.78211851952459,15.226246029512001,0.05818374391855817,20.603665756241284,4258.450859219057,16.684991885589945
train2.csv.gz,4537,180.380231487768,404.9985957328643,1.3885507626184759,54097.111968260964,13492.448038351333,10052.902358386598,26.01082617737414,11.712980419233306,0.07154540840596779,21.507712300569843,4301.554272480435,20.81453649334688
train3.csv.gz,9577,200.5677786457133,404.42271995927814,1.3747714273780822,54105.243552260625,13194.925368069333,10050.422313876998,25.685144220465013,17.301905753283886,0.05304501736117375,21.759589175104754,4403.031534167203,15.627077960123735
1 file sample_rows mean_P1_FT01 mean_P1_LIT01 mean_P1_PIT01 mean_P2_CO_rpm mean_P3_LIT01 mean_P4_ST_PT01 std_P1_FT01 std_P1_LIT01 std_P1_PIT01 std_P2_CO_rpm std_P3_LIT01 std_P4_ST_PT01
2 train1.csv.gz 4321 197.3625650034721 402.4983721615365 1.3701480791483485 54106.13631103911 13582.384170330943 10048.28187919463 25.78211851952459 15.226246029512001 0.05818374391855817 20.603665756241284 4258.450859219057 16.684991885589945
3 train2.csv.gz 4537 180.380231487768 404.9985957328643 1.3885507626184759 54097.111968260964 13492.448038351333 10052.902358386598 26.01082617737414 11.712980419233306 0.07154540840596779 21.507712300569843 4301.554272480435 20.81453649334688
4 train3.csv.gz 9577 200.5677786457133 404.42271995927814 1.3747714273780822 54105.243552260625 13194.925368069333 10050.422313876998 25.685144220465013 17.301905753283886 0.05304501736117375 21.759589175104754 4403.031534167203 15.627077960123735

View File

@@ -0,0 +1,4 @@
file,rows,mean_P1_FT01,mean_P1_LIT01,mean_P1_PIT01,mean_P2_CO_rpm,mean_P3_LIT01,mean_P4_ST_PT01,std_P1_FT01,std_P1_LIT01,std_P1_PIT01,std_P2_CO_rpm,std_P3_LIT01,std_P4_ST_PT01
train1.csv.gz,216001,197.4039421067864,402.499530273871,1.3699407928204004,54105.28641186847,13576.493878500563,10048.201686566266,25.725535038417004,15.22218627802303,0.05860248629393732,21.16202209482823,4264.6609588197925,16.8563584796801
train2.csv.gz,226801,180.30764089694264,404.99552063838513,1.3890573547293745,54095.33532480016,13482.01015758308,10052.992377679111,26.231835534624313,11.708752202389721,0.07297888433587706,22.315550737159924,4308.3181358982065,21.239090989129597
train3.csv.gz,478801,200.34315271607556,404.42644943316503,1.3753994816009067,54105.041362173426,13194.615527118782,10050.500895988103,26.111739592931507,17.298124670143782,0.05566509425248764,22.327041460872234,4409.0260830855495,16.17524060043852
1 file rows mean_P1_FT01 mean_P1_LIT01 mean_P1_PIT01 mean_P2_CO_rpm mean_P3_LIT01 mean_P4_ST_PT01 std_P1_FT01 std_P1_LIT01 std_P1_PIT01 std_P2_CO_rpm std_P3_LIT01 std_P4_ST_PT01
2 train1.csv.gz 216001 197.4039421067864 402.499530273871 1.3699407928204004 54105.28641186847 13576.493878500563 10048.201686566266 25.725535038417004 15.22218627802303 0.05860248629393732 21.16202209482823 4264.6609588197925 16.8563584796801
3 train2.csv.gz 226801 180.30764089694264 404.99552063838513 1.3890573547293745 54095.33532480016 13482.01015758308 10052.992377679111 26.231835534624313 11.708752202389721 0.07297888433587706 22.315550737159924 4308.3181358982065 21.239090989129597
4 train3.csv.gz 478801 200.34315271607556 404.42644943316503 1.3753994816009067 54105.041362173426 13194.615527118782 10050.500895988103 26.111739592931507 17.298124670143782 0.05566509425248764 22.327041460872234 4409.0260830855495 16.17524060043852

View File

@@ -180,6 +180,11 @@ Metrics (with reference):
- 追加记录到 `example/results/metrics_history.csv`
- 如果存在上一次记录,输出 delta(新旧对比)
**分布诊断脚本(逐特征 KS/CDF)** `example/diagnose_ks.py`
- 输出 `example/results/ks_per_feature.csv`(每个连续特征 KS)
- 输出 `example/results/cdf_<feature>.svg`(真实 vs 生成 CDF)
- 统计生成数据是否堆积在边界(gen_frac_at_min / gen_frac_at_max)
Recent run (user-reported, Windows):
- avg_ks 0.7096 / avg_jsd 0.03318 / avg_lag1_diff 0.18984