#!/usr/bin/env python3
|
|
"""Per-feature KS diagnostics and CDF visualization (no third-party deps)."""
|
|
|
|
import argparse
|
|
import csv
|
|
import gzip
|
|
import json
|
|
import math
|
|
from pathlib import Path
|
|
from glob import glob
|
|
|
|
|
|
def parse_args():
    """Build and parse the command-line interface for the KS diagnostics run."""
    here = Path(__file__).resolve().parent
    parser = argparse.ArgumentParser(description="Per-feature KS diagnostics.")
    option_specs = [
        ("--generated", dict(default=str(here / "results" / "generated.csv"))),
        ("--reference", dict(default=str(here / "config.json"))),
        ("--out-dir", dict(default=str(here / "results"))),
        ("--max-rows", dict(type=int, default=200000, help="<=0 for full scan")),
        ("--stride", dict(type=int, default=1, help="row stride sampling")),
        ("--top-k", dict(type=int, default=8)),
    ]
    for flag, kwargs in option_specs:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()
|
|
|
|
|
|
def load_split(base_dir: Path):
    """Return the continuous feature columns, excluding the time column.

    Reads ``feature_split.json`` from *base_dir*; the file must contain a
    ``continuous`` list and may name a ``time_column`` (default ``"time"``).
    """
    split_path = base_dir / "feature_split.json"
    with split_path.open("r", encoding="utf-8") as fh:
        split = json.load(fh)
    time_col = split.get("time_column", "time")
    return [col for col in split["continuous"] if col != time_col]
|
|
|
|
|
|
def resolve_reference_glob(base_dir: Path, ref_arg: str):
    """Turn the ``--reference`` argument into a glob string for data files.

    A ``.json`` argument is treated as a config file whose ``data_glob`` (or
    legacy ``data_path``) entry is joined relative to the config's own
    directory. Any other argument is passed through unchanged.
    """
    ref_path = Path(ref_arg)
    if ref_path.suffix != ".json":
        return str(ref_path)
    cfg = json.loads(ref_path.read_text(encoding="utf-8"))
    data_glob = cfg.get("data_glob") or cfg.get("data_path") or ""
    if not data_glob:
        raise SystemExit("reference config has no data_glob/data_path")
    combined = ref_path.parent / data_glob
    # On Windows, Path.resolve fails on glob patterns like *.csv.gz
    if any(wildcard in str(combined) for wildcard in "*?"):
        return str(combined)
    return str(combined.resolve())
|
|
|
|
|
|
def read_csv_values(path: Path, cols, max_rows=200000, stride=1, gz=True):
    """Read numeric values for *cols* from a CSV (optionally gzipped) file.

    Parameters
    ----------
    path : Path to the CSV file.
    cols : iterable of column names to extract.
    max_rows : stop after this many sampled rows (<=0 means full scan).
    stride : keep only every *stride*-th row (1 keeps all rows).
    gz : open through gzip when True, as plain UTF-8 text otherwise.

    Returns
    -------
    (values, row_count) : dict mapping column -> list of finite floats, and
    the number of rows actually sampled (counted even when all cells in a
    row fail to parse).
    """
    values = {c: [] for c in cols}
    row_count = 0
    opener = (lambda p: gzip.open(p, "rt", newline="")) if gz else (
        lambda p: open(p, "r", newline="", encoding="utf-8"))
    # `with` guarantees the handle is closed even if parsing raises,
    # replacing the manual try/finally of the original.
    with opener(path) as fh:
        reader = csv.DictReader(fh)
        for i, row in enumerate(reader):
            if stride > 1 and i % stride != 0:
                continue
            for c in cols:
                try:
                    fv = float(row.get(c, ""))
                except (TypeError, ValueError):
                    # Missing, empty, or non-numeric cell: skip this cell only.
                    continue
                # Drop NaN/inf so downstream KS math stays well-defined.
                if math.isfinite(fv):
                    values[c].append(fv)
            row_count += 1
            if max_rows > 0 and row_count >= max_rows:
                break
    return values, row_count
|
|
|
|
|
|
def ks_statistic(a, b):
    """Two-sample Kolmogorov-Smirnov statistic (max ECDF distance), in [0, 1].

    Identical samples yield 0.0; an empty sample yields the degenerate
    maximum of 1.0.

    Bug fixed: the previous merge advanced only ONE pointer on ties, so the
    ECDF gap was measured mid-tie — two identical lists reported D = 1.0
    instead of 0.0. Tied values must be consumed from both samples before
    the gap is measured.
    """
    if not a or not b:
        return 1.0
    a = sorted(a)
    b = sorted(b)
    na, nb = len(a), len(b)
    i = j = 0
    d = 0.0
    while i < na and j < nb:
        # Advance past the smallest current value in whichever sample(s)
        # contain it, then compare the ECDFs at that value.
        v = min(a[i], b[j])
        while i < na and a[i] == v:
            i += 1
        while j < nb and b[j] == v:
            j += 1
        d = max(d, abs(i / na - j / nb))
    return d
|
|
|
|
|
|
def ecdf_points(vals):
    """Return (xs, ys) step points of the empirical CDF of *vals*.

    xs holds the distinct values in ascending order; ys[k] is the fraction
    of samples <= xs[k], so duplicates collapse into a single raised step.
    Empty input yields ([], []).
    """
    ordered = sorted(vals)
    n = len(ordered)
    if n == 0:
        return [], []
    xs, ys = [], []
    for rank, v in enumerate(ordered, start=1):
        if xs and xs[-1] == v:
            # Repeated value: lift the existing step instead of adding one.
            ys[-1] = rank / n
        else:
            xs.append(v)
            ys.append(rank / n)
    return xs, ys
|
|
|
|
|
|
def render_cdf_svg(out_path: Path, feature, real_vals, gen_vals, bounds=None):
    """Write an SVG comparing the real vs. generated ECDFs for one feature.

    Nothing is written when either sample is empty. *bounds*, when given,
    is a (lo, hi) pair rendered as dashed vertical guide lines.
    """
    if not real_vals or not gen_vals:
        return
    width, height = 900, 420
    pad = 50
    panel_w = width - pad * 2
    panel_h = height - pad * 2
    min_v = min(min(real_vals), min(gen_vals))
    max_v = max(max(real_vals), max(gen_vals))
    if max_v == min_v:
        # Degenerate range: widen it so the x-scale never divides by zero.
        max_v += 1.0
    rx, ry = ecdf_points(real_vals)
    gx, gy = ecdf_points(gen_vals)

    def sx(v):
        # Data value -> integer x pixel inside the panel.
        return pad + int((v - min_v) * panel_w / (max_v - min_v))

    def sy(v):
        # CDF value in [0, 1] -> integer y pixel (SVG y grows downward).
        return pad + panel_h - int(v * panel_h)

    def path_from(xs, ys, color):
        pts = " ".join(f"{sx(x)},{sy(y)}" for x, y in zip(xs, ys))
        return f'<polyline fill="none" stroke="{color}" stroke-width="2" points="{pts}"/>'

    parts = [
        f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}">',
        '<style>text{font-family:Arial,sans-serif;font-size:12px}</style>',
        f'<text x="{pad}" y="{pad-20}">CDF 비교: {feature}</text>',
        f'<line x1="{pad}" y1="{pad}" x2="{pad}" y2="{pad+panel_h}" stroke="#333"/>',
        f'<line x1="{pad}" y1="{pad+panel_h}" x2="{pad+panel_w}" y2="{pad+panel_h}" stroke="#333"/>',
        path_from(rx, ry, "#1f77b4"),  # real series (blue)
        path_from(gx, gy, "#d62728"),  # generated series (red)
        f'<text x="{pad+panel_w-120}" y="{pad+15}" fill="#1f77b4">real</text>',
        f'<text x="{pad+panel_w-120}" y="{pad+30}" fill="#d62728">generated</text>',
    ]
    if bounds is not None:
        lo, hi = bounds
        for edge in (lo, hi):
            parts.append(
                f'<line x1="{sx(edge)}" y1="{pad}" x2="{sx(edge)}" y2="{pad+panel_h}" stroke="#999" stroke-dasharray="4 3"/>'
            )
    parts.append("</svg>")
    out_path.write_text("\n".join(parts), encoding="utf-8")
|
|
|
|
|
|
def main():
    """Entry point: compute per-feature KS stats and emit CSV/SVG/JSON reports."""
    args = parse_args()
    base_dir = Path(__file__).resolve().parent
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    cont_cols = load_split(base_dir)
    ref_glob = resolve_reference_glob(base_dir, args.reference)
    ref_files = sorted(glob(ref_glob))
    if not ref_files:
        raise SystemExit(f"no reference files matched: {ref_glob}")

    gen_path = Path(args.generated)
    # Fix: infer gzip-ness from the file suffix instead of hard-coding it
    # (generated was always plain text, references always gzip), so a
    # gzipped generated file or a plain-text reference also works.
    gen_vals, gen_rows = read_csv_values(
        gen_path, cont_cols, args.max_rows, args.stride, gz=gen_path.suffix == ".gz"
    )

    # Reference values (aggregate across files)
    real_vals = {c: [] for c in cont_cols}
    total_rows = 0
    for f in ref_files:
        fp = Path(f)
        vals, rows = read_csv_values(
            fp, cont_cols, args.max_rows, args.stride, gz=fp.suffix == ".gz"
        )
        total_rows += rows
        for c in cont_cols:
            real_vals[c].extend(vals[c])

    # KS per feature, plus how much generated mass piles up at the real
    # min/max (a symptom of clipping in the generator).
    ks_rows = []
    for c in cont_cols:
        ks = ks_statistic(real_vals[c], gen_vals[c])
        if real_vals[c]:
            lo = min(real_vals[c])
            hi = max(real_vals[c])
        else:
            lo = hi = 0.0
        # Relative tolerance for "at the boundary"; absolute fallback when
        # the real range is degenerate.
        tol = (hi - lo) * 1e-4 if hi > lo else 1e-6
        gen = gen_vals[c]
        if gen:
            frac_lo = sum(1 for v in gen if abs(v - lo) <= tol) / len(gen)
            frac_hi = sum(1 for v in gen if abs(v - hi) <= tol) / len(gen)
        else:
            frac_lo = frac_hi = 0.0
        ks_rows.append((c, ks, frac_lo, frac_hi, len(real_vals[c]), len(gen_vals[c]), lo, hi))

    # Worst features first.
    ks_rows.sort(key=lambda x: x[1], reverse=True)
    out_csv = out_dir / "ks_per_feature.csv"
    with out_csv.open("w", newline="") as fh:
        w = csv.writer(fh)
        w.writerow(["feature", "ks", "gen_frac_at_min", "gen_frac_at_max", "real_n", "gen_n", "real_min", "real_max"])
        for row in ks_rows:
            w.writerow(row)

    # top-k CDF plots
    for c, ks, _, _, _, _, lo, hi in ks_rows[: args.top_k]:
        out_svg = out_dir / f"cdf_{c}.svg"
        render_cdf_svg(out_svg, c, real_vals[c], gen_vals[c], bounds=(lo, hi))

    # summary
    summary = {
        "generated_rows": gen_rows,
        "reference_rows_per_file": args.max_rows if args.max_rows > 0 else "full",
        # Fix: total_rows was computed but never reported.
        "reference_rows_scanned": total_rows,
        "stride": args.stride,
        "top_k_features": [r[0] for r in ks_rows[: args.top_k]],
    }
    (out_dir / "ks_summary.json").write_text(json.dumps(summary, indent=2), encoding="utf-8")

    print(f"wrote {out_csv}")
    print(f"wrote CDF svgs for top {args.top_k} features")


if __name__ == "__main__":
    main()
|