diff --git a/docs/README.md b/docs/README.md
index 581720c..dc6b8d4 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -10,3 +10,6 @@ Conventions:
- Append new entries instead of overwriting old ones.
- Record exact config file and key overrides when possible.
- Keep metrics in the order: avg_ks / avg_jsd / avg_lag1_diff.
+
+Tools:
+- `example/diagnose_ks.py` for per-feature KS + CDF plots.
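+  - Defaults from its argparse setup: `--generated example/results/generated.csv`, `--reference example/config.json`, output under `example/results/`.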
diff --git a/docs/decisions.md b/docs/decisions.md
index 3f9d45b..b93e2bd 100644
--- a/docs/decisions.md
+++ b/docs/decisions.md
@@ -33,3 +33,9 @@
- `example/sample.py`
- `example/export_samples.py`
- `example/config.json`
+
+## 2026-01-26 — Per-feature KS diagnostics
+- **Decision**: Add a per-feature KS/CDF diagnostic script to pinpoint KS failures (tails, boundary pile-up, shifts).
+- **Why**: Avoid blind reweighting and find the specific features causing KS to stay high.
+- **Files**:
+ - `example/diagnose_ks.py`
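+- **Usage (sketch, flags as defined in the script)**:
+  - `python example/diagnose_ks.py --stride 10 --top-k 8`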
diff --git a/example/diagnose_ks.py b/example/diagnose_ks.py
new file mode 100644
index 0000000..60aaae1
--- /dev/null
+++ b/example/diagnose_ks.py
@@ -0,0 +1,227 @@
+#!/usr/bin/env python3
+"""Per-feature KS diagnostics and CDF visualization (no third-party deps)."""
+
+import argparse
+import csv
+import gzip
+import json
+import math
+from pathlib import Path
+from glob import glob
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Per-feature KS diagnostics.")
+ base_dir = Path(__file__).resolve().parent
+ parser.add_argument("--generated", default=str(base_dir / "results" / "generated.csv"))
+ parser.add_argument("--reference", default=str(base_dir / "config.json"))
+ parser.add_argument("--out-dir", default=str(base_dir / "results"))
+ parser.add_argument("--max-rows", type=int, default=200000, help="<=0 for full scan")
+ parser.add_argument("--stride", type=int, default=1, help="row stride sampling")
+ parser.add_argument("--top-k", type=int, default=8)
+ return parser.parse_args()
+
+
+def load_split(base_dir: Path):
+ with open(base_dir / "feature_split.json", "r", encoding="utf-8") as f:
+ split = json.load(f)
+ time_col = split.get("time_column", "time")
+ cont_cols = [c for c in split["continuous"] if c != time_col]
+ return cont_cols
+
+
+def resolve_reference_glob(base_dir: Path, ref_arg: str):
+ ref_path = Path(ref_arg)
+ if ref_path.suffix == ".json":
+ cfg = json.loads(ref_path.read_text(encoding="utf-8"))
+ data_glob = cfg.get("data_glob") or cfg.get("data_path") or ""
+ if not data_glob:
+ raise SystemExit("reference config has no data_glob/data_path")
+ ref_path = (ref_path.parent / data_glob).resolve()
+ return str(ref_path)
+ return str(ref_path)
+
+
+def read_csv_values(path: Path, cols, max_rows=200000, stride=1, gz=True):
+ values = {c: [] for c in cols}
+ row_count = 0
+ if gz:
+ fh = gzip.open(path, "rt", newline="")
+ else:
+ fh = open(path, "r", newline="", encoding="utf-8")
+ try:
+ reader = csv.DictReader(fh)
+ for i, row in enumerate(reader):
+ if stride > 1 and i % stride != 0:
+ continue
+ for c in cols:
+ v = row.get(c, "")
+ try:
+ fv = float(v)
+ if math.isfinite(fv):
+ values[c].append(fv)
+ except Exception:
+ continue
+ row_count += 1
+ if max_rows > 0 and row_count >= max_rows:
+ break
+ finally:
+ fh.close()
+ return values, row_count
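+
+# Illustrative usage (file and column names from this repo's data; adjust as needed):
+#   vals, n = read_csv_values(Path("train1.csv.gz"), ["P1_FT01"], max_rows=50000, stride=10)
+# This keeps every 10th row, stops after 50000 sampled rows, and skips non-finite cells.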
+
+
+def ks_statistic(a, b):
+ if not a or not b:
+ return 1.0
+ a = sorted(a)
+ b = sorted(b)
+ na = len(a)
+ nb = len(b)
+ i = j = 0
+ d = 0.0
+    while i < na and j < nb:
+        # Advance both pointers through tied values so the ECDF gap is
+        # measured between jumps; comparing mid-tie inflates D (e.g. two
+        # identical samples would otherwise report D = 1 instead of 0).
+        x = a[i] if a[i] <= b[j] else b[j]
+        while i < na and a[i] == x:
+            i += 1
+        while j < nb and b[j] == x:
+            j += 1
+        d = max(d, abs(i / na - j / nb))
+ return d
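+
+# Sanity checks (illustrative): identical samples give D = 0.0;
+# samples with disjoint ranges give D = 1.0.
+#   ks_statistic([1.0, 2.0, 3.0], [1.0, 2.0, 3.0])  # -> 0.0
+#   ks_statistic([0.0, 1.0], [10.0, 11.0])          # -> 1.0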
+
+
+def ecdf_points(vals):
+ vals = sorted(vals)
+ n = len(vals)
+ if n == 0:
+ return [], []
+ xs = []
+ ys = []
+ last = None
+ for i, v in enumerate(vals, 1):
+ if last is None or v != last:
+ xs.append(v)
+ ys.append(i / n)
+ last = v
+ else:
+ ys[-1] = i / n
+ return xs, ys
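+
+# Example (illustrative): ties collapse into a single CDF step.
+#   ecdf_points([1.0, 2.0, 2.0, 3.0])  # -> ([1.0, 2.0, 3.0], [0.25, 0.75, 1.0])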
+
+
+def render_cdf_svg(out_path: Path, feature, real_vals, gen_vals, bounds=None):
+ width, height = 900, 420
+ pad = 50
+ panel_w = width - pad * 2
+ panel_h = height - pad * 2
+ if not real_vals or not gen_vals:
+ return
+ min_v = min(min(real_vals), min(gen_vals))
+ max_v = max(max(real_vals), max(gen_vals))
+ if max_v == min_v:
+ max_v += 1.0
+ rx, ry = ecdf_points(real_vals)
+ gx, gy = ecdf_points(gen_vals)
+
+ def sx(v):
+ return pad + int((v - min_v) * panel_w / (max_v - min_v))
+
+ def sy(v):
+ return pad + panel_h - int(v * panel_h)
+
+    svg = []
+    svg.append(f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}">')
+    svg.append(f'<rect width="{width}" height="{height}" fill="white"/>')
+    svg.append(f'<text x="{pad}" y="{pad - 16}" font-size="14">{feature}: empirical CDF (blue=real, red=generated)</text>')
+    if bounds:
+        # Dashed verticals at the real-data min/max make boundary pile-up visible.
+        for b in bounds:
+            svg.append(f'<line x1="{sx(b)}" y1="{pad}" x2="{sx(b)}" y2="{pad + panel_h}" stroke="#999" stroke-dasharray="4,3"/>')
+    for xs, ys, color in ((rx, ry, "#1f77b4"), (gx, gy, "#d62728")):
+        pts = " ".join(f"{sx(x)},{sy(y)}" for x, y in zip(xs, ys))
+        svg.append(f'<polyline fill="none" stroke="{color}" stroke-width="1.5" points="{pts}"/>')
+    svg.append("</svg>")
+    out_path.write_text("\n".join(svg), encoding="utf-8")
+
+
+def main():
+ args = parse_args()
+ base_dir = Path(__file__).resolve().parent
+ out_dir = Path(args.out_dir)
+ out_dir.mkdir(parents=True, exist_ok=True)
+
+ cont_cols = load_split(base_dir)
+ ref_glob = resolve_reference_glob(base_dir, args.reference)
+ ref_files = sorted(glob(ref_glob))
+ if not ref_files:
+ raise SystemExit(f"no reference files matched: {ref_glob}")
+
+ gen_path = Path(args.generated)
+ gen_vals, gen_rows = read_csv_values(gen_path, cont_cols, args.max_rows, args.stride, gz=False)
+
+ # Reference values (aggregate across files)
+ real_vals = {c: [] for c in cont_cols}
+ total_rows = 0
+ for f in ref_files:
+ vals, rows = read_csv_values(Path(f), cont_cols, args.max_rows, args.stride, gz=True)
+ total_rows += rows
+ for c in cont_cols:
+ real_vals[c].extend(vals[c])
+
+ # KS per feature
+ ks_rows = []
+ for c in cont_cols:
+ ks = ks_statistic(real_vals[c], gen_vals[c])
+ # boundary pile-up (using real min/max)
+ if real_vals[c]:
+ lo = min(real_vals[c])
+ hi = max(real_vals[c])
+ else:
+ lo = hi = 0.0
+ tol = (hi - lo) * 1e-4 if hi > lo else 1e-6
+ gen = gen_vals[c]
+ if gen:
+ frac_lo = sum(1 for v in gen if abs(v - lo) <= tol) / len(gen)
+ frac_hi = sum(1 for v in gen if abs(v - hi) <= tol) / len(gen)
+ else:
+ frac_lo = frac_hi = 0.0
+ ks_rows.append((c, ks, frac_lo, frac_hi, len(real_vals[c]), len(gen_vals[c]), lo, hi))
+
+ ks_rows.sort(key=lambda x: x[1], reverse=True)
+ out_csv = out_dir / "ks_per_feature.csv"
+ with out_csv.open("w", newline="") as fh:
+ w = csv.writer(fh)
+ w.writerow(["feature", "ks", "gen_frac_at_min", "gen_frac_at_max", "real_n", "gen_n", "real_min", "real_max"])
+ for row in ks_rows:
+ w.writerow(row)
+
+ # top-k CDF plots
+ for c, ks, _, _, _, _, lo, hi in ks_rows[: args.top_k]:
+ out_svg = out_dir / f"cdf_{c}.svg"
+ render_cdf_svg(out_svg, c, real_vals[c], gen_vals[c], bounds=(lo, hi))
+
+ # summary
+ summary = {
+ "generated_rows": gen_rows,
+ "reference_rows_per_file": args.max_rows if args.max_rows > 0 else "full",
+ "stride": args.stride,
+ "top_k_features": [r[0] for r in ks_rows[: args.top_k]],
+ }
+ (out_dir / "ks_summary.json").write_text(json.dumps(summary, indent=2), encoding="utf-8")
+
+ print(f"wrote {out_csv}")
+ print(f"wrote CDF svgs for top {args.top_k} features")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/example/results/data_shift_plot.svg b/example/results/data_shift_plot.svg
new file mode 100644
index 0000000..12038c4
--- /dev/null
+++ b/example/results/data_shift_plot.svg
@@ -0,0 +1,43 @@
+
\ No newline at end of file
diff --git a/example/results/data_shift_plot_full.svg b/example/results/data_shift_plot_full.svg
new file mode 100644
index 0000000..a8681f1
--- /dev/null
+++ b/example/results/data_shift_plot_full.svg
@@ -0,0 +1,43 @@
+
\ No newline at end of file
diff --git a/example/results/data_shift_stats.csv b/example/results/data_shift_stats.csv
new file mode 100644
index 0000000..61e6642
--- /dev/null
+++ b/example/results/data_shift_stats.csv
@@ -0,0 +1,4 @@
+file,sample_rows,mean_P1_FT01,mean_P1_LIT01,mean_P1_PIT01,mean_P2_CO_rpm,mean_P3_LIT01,mean_P4_ST_PT01,std_P1_FT01,std_P1_LIT01,std_P1_PIT01,std_P2_CO_rpm,std_P3_LIT01,std_P4_ST_PT01
+train1.csv.gz,4321,197.3625650034721,402.4983721615365,1.3701480791483485,54106.13631103911,13582.384170330943,10048.28187919463,25.78211851952459,15.226246029512001,0.05818374391855817,20.603665756241284,4258.450859219057,16.684991885589945
+train2.csv.gz,4537,180.380231487768,404.9985957328643,1.3885507626184759,54097.111968260964,13492.448038351333,10052.902358386598,26.01082617737414,11.712980419233306,0.07154540840596779,21.507712300569843,4301.554272480435,20.81453649334688
+train3.csv.gz,9577,200.5677786457133,404.42271995927814,1.3747714273780822,54105.243552260625,13194.925368069333,10050.422313876998,25.685144220465013,17.301905753283886,0.05304501736117375,21.759589175104754,4403.031534167203,15.627077960123735
diff --git a/example/results/data_shift_stats_full.csv b/example/results/data_shift_stats_full.csv
new file mode 100644
index 0000000..fec073e
--- /dev/null
+++ b/example/results/data_shift_stats_full.csv
@@ -0,0 +1,4 @@
+file,rows,mean_P1_FT01,mean_P1_LIT01,mean_P1_PIT01,mean_P2_CO_rpm,mean_P3_LIT01,mean_P4_ST_PT01,std_P1_FT01,std_P1_LIT01,std_P1_PIT01,std_P2_CO_rpm,std_P3_LIT01,std_P4_ST_PT01
+train1.csv.gz,216001,197.4039421067864,402.499530273871,1.3699407928204004,54105.28641186847,13576.493878500563,10048.201686566266,25.725535038417004,15.22218627802303,0.05860248629393732,21.16202209482823,4264.6609588197925,16.8563584796801
+train2.csv.gz,226801,180.30764089694264,404.99552063838513,1.3890573547293745,54095.33532480016,13482.01015758308,10052.992377679111,26.231835534624313,11.708752202389721,0.07297888433587706,22.315550737159924,4308.3181358982065,21.239090989129597
+train3.csv.gz,478801,200.34315271607556,404.42644943316503,1.3753994816009067,54105.041362173426,13194.615527118782,10050.500895988103,26.111739592931507,17.298124670143782,0.05566509425248764,22.327041460872234,4409.0260830855495,16.17524060043852
diff --git a/report.md b/report.md
index 799e611..64e78a5 100644
--- a/report.md
+++ b/report.md
@@ -180,6 +180,11 @@ Metrics (with reference):
- Appends a record to `example/results/metrics_history.csv`
- If a previous record exists, prints the delta (new vs. old comparison)
+**Distribution diagnostics script (per-feature KS/CDF):** `example/diagnose_ks.py`
+- Writes `example/results/ks_per_feature.csv` (KS statistic per continuous feature)
+- Writes `example/results/cdf_<feature>.svg` (real vs. generated CDF)
+- Flags whether generated values pile up at the boundaries (gen_frac_at_min / gen_frac_at_max)
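+- Example invocation (using the script's own flags): `python example/diagnose_ks.py --stride 10 --top-k 8`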
+
Recent run (user-reported, Windows):
- avg_ks 0.7096 / avg_jsd 0.03318 / avg_lag1_diff 0.18984