From dc78f336bc8ee7f426781b4d26da77089ae554c2 Mon Sep 17 00:00:00 2001
From: MingzheYang
Date: Tue, 27 Jan 2026 18:19:07 +0800
Subject: [PATCH] update2

---
 docs/README.md                            |   3 +
 docs/decisions.md                         |   6 +
 example/diagnose_ks.py                    | 227 ++++++++++++++++++++++
 example/results/data_shift_plot.svg       |  43 ++++
 example/results/data_shift_plot_full.svg  |  43 ++++
 example/results/data_shift_stats.csv      |   4 +
 example/results/data_shift_stats_full.csv |   4 +
 report.md                                 |   5 +
 8 files changed, 335 insertions(+)
 create mode 100644 example/diagnose_ks.py
 create mode 100644 example/results/data_shift_plot.svg
 create mode 100644 example/results/data_shift_plot_full.svg
 create mode 100644 example/results/data_shift_stats.csv
 create mode 100644 example/results/data_shift_stats_full.csv

diff --git a/docs/README.md b/docs/README.md
index 581720c..dc6b8d4 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -10,3 +10,6 @@ Conventions:
 - Append new entries instead of overwriting old ones.
 - Record exact config file and key overrides when possible.
 - Keep metrics in the order: avg_ks / avg_jsd / avg_lag1_diff.
+
+Tools:
+- `example/diagnose_ks.py` for per-feature KS + CDF plots.
diff --git a/docs/decisions.md b/docs/decisions.md
index 3f9d45b..b93e2bd 100644
--- a/docs/decisions.md
+++ b/docs/decisions.md
@@ -33,3 +33,9 @@
 - `example/sample.py`
 - `example/export_samples.py`
 - `example/config.json`
+
+## 2026-01-26 — Per-feature KS diagnostics
+- **Decision**: Add a per-feature KS/CDF diagnostic script to pinpoint KS failures (tails, boundary pile-up, shifts).
+- **Why**: Avoid blind reweighting and find the specific features causing KS to stay high.
+- **Files**:
+  - `example/diagnose_ks.py`
diff --git a/example/diagnose_ks.py b/example/diagnose_ks.py
new file mode 100644
index 0000000..60aaae1
--- /dev/null
+++ b/example/diagnose_ks.py
@@ -0,0 +1,227 @@
+#!/usr/bin/env python3
+"""Per-feature KS diagnostics and CDF visualization (no third-party deps)."""
+
+import argparse
+import csv
+import gzip
+import json
+import math
+from pathlib import Path
+from glob import glob
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Per-feature KS diagnostics.")
+    base_dir = Path(__file__).resolve().parent
+    parser.add_argument("--generated", default=str(base_dir / "results" / "generated.csv"))
+    parser.add_argument("--reference", default=str(base_dir / "config.json"))
+    parser.add_argument("--out-dir", default=str(base_dir / "results"))
+    parser.add_argument("--max-rows", type=int, default=200000, help="<=0 for full scan")
+    parser.add_argument("--stride", type=int, default=1, help="row stride sampling")
+    parser.add_argument("--top-k", type=int, default=8)
+    return parser.parse_args()
+
+
+def load_split(base_dir: Path):
+    with open(base_dir / "feature_split.json", "r", encoding="utf-8") as f:
+        split = json.load(f)
+    time_col = split.get("time_column", "time")
+    cont_cols = [c for c in split["continuous"] if c != time_col]
+    return cont_cols
+
+
+def resolve_reference_glob(base_dir: Path, ref_arg: str):
+    ref_path = Path(ref_arg)
+    if ref_path.suffix == ".json":
+        cfg = json.loads(ref_path.read_text(encoding="utf-8"))
+        data_glob = cfg.get("data_glob") or cfg.get("data_path") or ""
+        if not data_glob:
+            raise SystemExit("reference config has no data_glob/data_path")
+        ref_path = (ref_path.parent / data_glob).resolve()
+    return str(ref_path)
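+
+
+# Illustrative shape of the reference config read above (an assumption inferred
+# from the cfg.get() keys, not a documented schema):
+#     {"data_glob": "data/train*.csv.gz", ...}
+# A concrete file or glob can also be passed directly via --reference.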
+
+
+def read_csv_values(path: Path, cols, max_rows=200000, stride=1, gz=True):
+    values = {c: [] for c in cols}
+    row_count = 0
+    if gz:
+        fh = gzip.open(path, "rt", newline="")
+    else:
+        fh = open(path, "r", newline="", encoding="utf-8")
+    try:
+        reader = csv.DictReader(fh)
+        for i, row in enumerate(reader):
+            if stride > 1 and i % stride != 0:
+                continue
+            for c in cols:
+                v = row.get(c, "")
+                try:
+                    fv = float(v)
+                except ValueError:
+                    continue
+                if math.isfinite(fv):
+                    values[c].append(fv)
+            row_count += 1
+            if max_rows > 0 and row_count >= max_rows:
+                break
+    finally:
+        fh.close()
+    return values, row_count
+
+
+def ks_statistic(a, b):
+    """Two-sample KS statistic via a sorted merge; returns 1.0 if a sample is empty."""
+    if not a or not b:
+        return 1.0
+    a = sorted(a)
+    b = sorted(b)
+    na = len(a)
+    nb = len(b)
+    i = j = 0
+    d = 0.0
+    while i < na and j < nb:
+        # Advance through ties on both sides before measuring the gap;
+        # stepping one value at a time inflates D on repeated values
+        # (exactly the boundary pile-up this script looks for).
+        x = a[i] if a[i] <= b[j] else b[j]
+        while i < na and a[i] == x:
+            i += 1
+        while j < nb and b[j] == x:
+            j += 1
+        d = max(d, abs(i / na - j / nb))
+    return d
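+
+
+# Worked toy example for ks_statistic (illustration only):
+#   a = [1, 2, 3], b = [2, 3, 4]
+#   ECDF gaps at x = 1, 2, 3, 4 are 1/3, 1/3, 1/3, 0, so D = 1/3.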
+
+
+def ecdf_points(vals):
+    vals = sorted(vals)
+    n = len(vals)
+    if n == 0:
+        return [], []
+    xs = []
+    ys = []
+    last = None
+    for i, v in enumerate(vals, 1):
+        if last is None or v != last:
+            xs.append(v)
+            ys.append(i / n)
+            last = v
+        else:
+            ys[-1] = i / n
+    return xs, ys
+
+
+def render_cdf_svg(out_path: Path, feature, real_vals, gen_vals, bounds=None):
+    width, height = 900, 420
+    pad = 50
+    panel_w = width - pad * 2
+    panel_h = height - pad * 2
+    if not real_vals or not gen_vals:
+        return
+    min_v = min(min(real_vals), min(gen_vals))
+    max_v = max(max(real_vals), max(gen_vals))
+    if max_v == min_v:
+        max_v += 1.0
+    rx, ry = ecdf_points(real_vals)
+    gx, gy = ecdf_points(gen_vals)
+
+    def sx(v):
+        return pad + int((v - min_v) * panel_w / (max_v - min_v))
+
+    def sy(v):
+        return pad + panel_h - int(v * panel_h)
+
+    # Hand-rolled SVG; styling values are minimal.
+    svg = []
+    svg.append(f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}">')
+    svg.append('<rect width="100%" height="100%" fill="white"/>')
+    svg.append(f'<text x="{pad}" y="{pad - 16}" font-size="16">CDF comparison: {feature}</text>')
+    svg.append(f'<line x1="{pad}" y1="{pad + panel_h}" x2="{pad + panel_w}" y2="{pad + panel_h}" stroke="#999"/>')
+    svg.append(f'<line x1="{pad}" y1="{pad}" x2="{pad}" y2="{pad + panel_h}" stroke="#999"/>')
+
+    def path_from(xs, ys, color):
+        pts = " ".join(f"{sx(x)},{sy(y)}" for x, y in zip(xs, ys))
+        return f'<polyline points="{pts}" fill="none" stroke="{color}" stroke-width="1.5"/>'
+
+    svg.append(path_from(rx, ry, "#1f77b4"))  # real
+    svg.append(path_from(gx, gy, "#d62728"))  # gen
+    svg.append(f'<text x="{pad + 10}" y="{pad + 16}" fill="#1f77b4">real</text>')
+    svg.append(f'<text x="{pad + 10}" y="{pad + 34}" fill="#d62728">generated</text>')
+
+    if bounds is not None:
+        lo, hi = bounds
+        svg.append(f'<line x1="{sx(lo)}" y1="{pad}" x2="{sx(lo)}" y2="{pad + panel_h}" stroke="#888" stroke-dasharray="4,3"/>')
+        svg.append(f'<line x1="{sx(hi)}" y1="{pad}" x2="{sx(hi)}" y2="{pad + panel_h}" stroke="#888" stroke-dasharray="4,3"/>')
+
+    svg.append("</svg>")
+    out_path.write_text("\n".join(svg), encoding="utf-8")
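+
+
+# Intuition for the boundary pile-up check in main() below (toy numbers,
+# illustration only): with lo = 0.0, hi = 1.0 and generated values
+# [0.0, 0.0, 0.5, 1.0], gen_frac_at_min = 0.5 and gen_frac_at_max = 0.25,
+# i.e. half the generated mass is stuck at the lower clip boundary.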
+
+
+def main():
+    args = parse_args()
+    base_dir = Path(__file__).resolve().parent
+    out_dir = Path(args.out_dir)
+    out_dir.mkdir(parents=True, exist_ok=True)
+
+    cont_cols = load_split(base_dir)
+    ref_glob = resolve_reference_glob(base_dir, args.reference)
+    ref_files = sorted(glob(ref_glob))
+    if not ref_files:
+        raise SystemExit(f"no reference files matched: {ref_glob}")
+
+    gen_path = Path(args.generated)
+    gen_vals, gen_rows = read_csv_values(gen_path, cont_cols, args.max_rows, args.stride, gz=False)
+
+    # Reference values (aggregate across files)
+    real_vals = {c: [] for c in cont_cols}
+    total_rows = 0
+    for f in ref_files:
+        vals, rows = read_csv_values(Path(f), cont_cols, args.max_rows, args.stride, gz=True)
+        total_rows += rows
+        for c in cont_cols:
+            real_vals[c].extend(vals[c])
+
+    # KS per feature
+    ks_rows = []
+    for c in cont_cols:
+        ks = ks_statistic(real_vals[c], gen_vals[c])
+        # boundary pile-up (using real min/max)
+        if real_vals[c]:
+            lo = min(real_vals[c])
+            hi = max(real_vals[c])
+        else:
+            lo = hi = 0.0
+        tol = (hi - lo) * 1e-4 if hi > lo else 1e-6
+        gen = gen_vals[c]
+        if gen:
+            frac_lo = sum(1 for v in gen if abs(v - lo) <= tol) / len(gen)
+            frac_hi = sum(1 for v in gen if abs(v - hi) <= tol) / len(gen)
+        else:
+            frac_lo = frac_hi = 0.0
+        ks_rows.append((c, ks, frac_lo, frac_hi, len(real_vals[c]), len(gen_vals[c]), lo, hi))
+
+    ks_rows.sort(key=lambda x: x[1], reverse=True)
+
+    out_csv = out_dir / "ks_per_feature.csv"
+    with out_csv.open("w", newline="") as fh:
+        w = csv.writer(fh)
+        w.writerow(["feature", "ks", "gen_frac_at_min", "gen_frac_at_max", "real_n", "gen_n", "real_min", "real_max"])
+        for row in ks_rows:
+            w.writerow(row)
+
+    # top-k CDF plots
+    for c, ks, _, _, _, _, lo, hi in ks_rows[: args.top_k]:
+        out_svg = out_dir / f"cdf_{c}.svg"
+        render_cdf_svg(out_svg, c, real_vals[c], gen_vals[c], bounds=(lo, hi))
+
+    # summary (total_rows feeds reference_rows_scanned; it was otherwise unused)
+    summary = {
+        "generated_rows": gen_rows,
+        "reference_rows_scanned": total_rows,
+        "reference_rows_per_file": args.max_rows if args.max_rows > 0 else "full",
+        "stride": args.stride,
+        "top_k_features": [r[0] for r in ks_rows[: args.top_k]],
+    }
+    (out_dir / "ks_summary.json").write_text(json.dumps(summary, indent=2), encoding="utf-8")
+
+    print(f"wrote {out_csv}")
+    print(f"wrote CDF svgs for top {args.top_k} features")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/example/results/data_shift_plot.svg b/example/results/data_shift_plot.svg
new file mode 100644
index 0000000..12038c4
--- /dev/null
+++ b/example/results/data_shift_plot.svg
@@ -0,0 +1,43 @@
+[SVG figure: panels "Per-file mean (sampled every 50 rows)" and "Per-file std (sampled every 50 rows)"; series train1.csv.gz, train2.csv.gz, train3.csv.gz over P1_FT01, P1_LIT01, P1_PIT01, P2_CO_rpm, P3_LIT01, P4_ST_PT01; axis labels 54106.136 / 1.370 (mean panel) and 4403.032 / 0.053 (std panel)]
\ No newline at end of file
diff --git a/example/results/data_shift_plot_full.svg b/example/results/data_shift_plot_full.svg
new file mode 100644
index 0000000..a8681f1
--- /dev/null
+++ b/example/results/data_shift_plot_full.svg
@@ -0,0 +1,43 @@
+[SVG figure: panels "Per-file mean (full data)" and "Per-file std (full data)"; series train1.csv.gz, train2.csv.gz, train3.csv.gz over the same six features; axis labels 54105.286 / 1.370 (mean panel) and 4409.026 / 0.056 (std panel)]
\ No newline at end of file
diff --git a/example/results/data_shift_stats.csv b/example/results/data_shift_stats.csv
new file mode 100644
index 0000000..61e6642
--- /dev/null
+++ b/example/results/data_shift_stats.csv
@@ -0,0 +1,4 @@
+file,sample_rows,mean_P1_FT01,mean_P1_LIT01,mean_P1_PIT01,mean_P2_CO_rpm,mean_P3_LIT01,mean_P4_ST_PT01,std_P1_FT01,std_P1_LIT01,std_P1_PIT01,std_P2_CO_rpm,std_P3_LIT01,std_P4_ST_PT01
+train1.csv.gz,4321,197.3625650034721,402.4983721615365,1.3701480791483485,54106.13631103911,13582.384170330943,10048.28187919463,25.78211851952459,15.226246029512001,0.05818374391855817,20.603665756241284,4258.450859219057,16.684991885589945
+train2.csv.gz,4537,180.380231487768,404.9985957328643,1.3885507626184759,54097.111968260964,13492.448038351333,10052.902358386598,26.01082617737414,11.712980419233306,0.07154540840596779,21.507712300569843,4301.554272480435,20.81453649334688
+train3.csv.gz,9577,200.5677786457133,404.42271995927814,1.3747714273780822,54105.243552260625,13194.925368069333,10050.422313876998,25.685144220465013,17.301905753283886,0.05304501736117375,21.759589175104754,4403.031534167203,15.627077960123735
diff --git a/example/results/data_shift_stats_full.csv b/example/results/data_shift_stats_full.csv
new file mode 100644
index 0000000..fec073e
--- /dev/null
+++ b/example/results/data_shift_stats_full.csv
@@ -0,0 +1,4 @@
+file,rows,mean_P1_FT01,mean_P1_LIT01,mean_P1_PIT01,mean_P2_CO_rpm,mean_P3_LIT01,mean_P4_ST_PT01,std_P1_FT01,std_P1_LIT01,std_P1_PIT01,std_P2_CO_rpm,std_P3_LIT01,std_P4_ST_PT01
+train1.csv.gz,216001,197.4039421067864,402.499530273871,1.3699407928204004,54105.28641186847,13576.493878500563,10048.201686566266,25.725535038417004,15.22218627802303,0.05860248629393732,21.16202209482823,4264.6609588197925,16.8563584796801
+train2.csv.gz,226801,180.30764089694264,404.99552063838513,1.3890573547293745,54095.33532480016,13482.01015758308,10052.992377679111,26.231835534624313,11.708752202389721,0.07297888433587706,22.315550737159924,4308.3181358982065,21.239090989129597
+train3.csv.gz,478801,200.34315271607556,404.42644943316503,1.3753994816009067,54105.041362173426,13194.615527118782,10050.500895988103,26.111739592931507,17.298124670143782,0.05566509425248764,22.327041460872234,4409.0260830855495,16.17524060043852
diff --git a/report.md b/report.md
index 799e611..64e78a5 100644
--- a/report.md
+++ b/report.md
@@ -180,6 +180,11 @@ Metrics (with reference):
 - Append a record to `example/results/metrics_history.csv`
 - If a previous record exists, print the delta (old vs. new)
 
+**Distribution diagnostics script (per-feature KS/CDF):** `example/diagnose_ks.py`
+- Writes `example/results/ks_per_feature.csv` (KS statistic for each continuous feature)
+- Writes `example/results/cdf_<feature>.svg` (real vs. generated CDF)
+- Reports whether generated data piles up at the boundaries (gen_frac_at_min / gen_frac_at_max)
+
 Recent run (user-reported, Windows):
 - avg_ks 0.7096 / avg_jsd 0.03318 / avg_lag1_diff 0.18984
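+
+Example invocation (a sketch; flags are the ones defined in `example/diagnose_ks.py`, and `--max-rows 0` scans full files):
+
+```bash
+python example/diagnose_ks.py --max-rows 0 --stride 5
+```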