mask-ddpm/example/diagnose_ks.py
#!/usr/bin/env python3
"""Per-feature KS diagnostics and CDF visualization (no third-party deps)."""
import argparse
import csv
import gzip
import json
import math
from pathlib import Path
from glob import glob


def parse_args():
    parser = argparse.ArgumentParser(description="Per-feature KS diagnostics.")
    base_dir = Path(__file__).resolve().parent
    parser.add_argument("--generated", default=str(base_dir / "results" / "generated.csv"))
    parser.add_argument("--reference", default=str(base_dir / "config.json"))
    parser.add_argument("--out-dir", default=str(base_dir / "results"))
    parser.add_argument("--max-rows", type=int, default=200000, help="<=0 for full scan")
    parser.add_argument("--stride", type=int, default=1, help="row stride sampling")
    parser.add_argument("--top-k", type=int, default=8)
    return parser.parse_args()
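
# Example invocation (illustrative only; the file names are just the argparse
# defaults above, and the --max-rows/--stride values are arbitrary):
#   python diagnose_ks.py --generated results/generated.csv \
#       --reference config.json --out-dir results --max-rows 50000 --stride 4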


def load_split(base_dir: Path):
    with open(base_dir / "feature_split.json", "r", encoding="utf-8") as f:
        split = json.load(f)
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    return cont_cols


def resolve_reference_glob(base_dir: Path, ref_arg: str):
    ref_path = Path(ref_arg)
    if ref_path.suffix == ".json":
        cfg = json.loads(ref_path.read_text(encoding="utf-8"))
        data_glob = cfg.get("data_glob") or cfg.get("data_path") or ""
        if not data_glob:
            raise SystemExit("reference config has no data_glob/data_path")
        combined = ref_path.parent / data_glob
        # On Windows, Path.resolve fails on glob patterns like *.csv.gz
        if "*" in str(combined) or "?" in str(combined):
            return str(combined)
        return str(combined.resolve())
    return str(ref_path)
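
# The reference config is expected to point at the gzipped reference CSVs via
# "data_glob" (or a single "data_path"), relative to the config file. A purely
# illustrative config.json (not taken from the repo) might contain:
#   {"data_glob": "data/*.csv.gz"}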


def read_csv_values(path: Path, cols, max_rows=200000, stride=1, gz=True):
    values = {c: [] for c in cols}
    row_count = 0
    if gz:
        fh = gzip.open(path, "rt", newline="")
    else:
        fh = open(path, "r", newline="", encoding="utf-8")
    try:
        reader = csv.DictReader(fh)
        for i, row in enumerate(reader):
            if stride > 1 and i % stride != 0:
                continue
            for c in cols:
                v = row.get(c, "")
                try:
                    fv = float(v)
                    if math.isfinite(fv):
                        values[c].append(fv)
                except Exception:
                    continue
            row_count += 1
            if max_rows > 0 and row_count >= max_rows:
                break
    finally:
        fh.close()
    return values, row_count


def ks_statistic(a, b):
    # Two-sample Kolmogorov-Smirnov statistic: the maximum gap between the two
    # empirical CDFs. Tied values must advance both pointers before the ECDFs
    # are compared, otherwise identical samples report a spuriously large D.
    if not a or not b:
        return 1.0
    a = sorted(a)
    b = sorted(b)
    na = len(a)
    nb = len(b)
    i = j = 0
    d = 0.0
    while i < na and j < nb:
        x = min(a[i], b[j])
        while i < na and a[i] == x:
            i += 1
        while j < nb and b[j] == x:
            j += 1
        d = max(d, abs(i / na - j / nb))
    return d
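
# Quick sanity checks for ks_statistic (values follow from the definition above):
#   ks_statistic([0.0, 1.0], [0.0, 1.0]) == 0.0   # identical samples
#   ks_statistic([0.0], [1.0]) == 1.0             # fully separated samples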


def ecdf_points(vals):
    vals = sorted(vals)
    n = len(vals)
    if n == 0:
        return [], []
    xs = []
    ys = []
    last = None
    for i, v in enumerate(vals, 1):
        if last is None or v != last:
            xs.append(v)
            ys.append(i / n)
            last = v
        else:
            ys[-1] = i / n
    return xs, ys
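
# ecdf_points collapses ties so each distinct value carries its final ECDF step,
# e.g. ecdf_points([1.0, 1.0, 2.0]) -> ([1.0, 2.0], [2/3, 1.0]).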


def render_cdf_svg(out_path: Path, feature, real_vals, gen_vals, bounds=None):
    width, height = 900, 420
    pad = 50
    panel_w = width - pad * 2
    panel_h = height - pad * 2
    if not real_vals or not gen_vals:
        return
    min_v = min(min(real_vals), min(gen_vals))
    max_v = max(max(real_vals), max(gen_vals))
    if max_v == min_v:
        max_v += 1.0
    rx, ry = ecdf_points(real_vals)
    gx, gy = ecdf_points(gen_vals)

    def sx(v):
        return pad + int((v - min_v) * panel_w / (max_v - min_v))

    def sy(v):
        return pad + panel_h - int(v * panel_h)

    svg = []
    svg.append(f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}">')
    svg.append('<style>text{font-family:Arial,sans-serif;font-size:12px}</style>')
    svg.append(f'<text x="{pad}" y="{pad-20}">CDF comparison: {feature}</text>')
    svg.append(f'<line x1="{pad}" y1="{pad}" x2="{pad}" y2="{pad+panel_h}" stroke="#333"/>')
    svg.append(f'<line x1="{pad}" y1="{pad+panel_h}" x2="{pad+panel_w}" y2="{pad+panel_h}" stroke="#333"/>')

    def path_from(xs, ys, color):
        pts = " ".join(f"{sx(x)},{sy(y)}" for x, y in zip(xs, ys))
        return f'<polyline fill="none" stroke="{color}" stroke-width="2" points="{pts}"/>'

    svg.append(path_from(rx, ry, "#1f77b4"))  # real
    svg.append(path_from(gx, gy, "#d62728"))  # gen
    svg.append(f'<text x="{pad+panel_w-120}" y="{pad+15}" fill="#1f77b4">real</text>')
    svg.append(f'<text x="{pad+panel_w-120}" y="{pad+30}" fill="#d62728">generated</text>')
    if bounds is not None:
        lo, hi = bounds
        svg.append(f'<line x1="{sx(lo)}" y1="{pad}" x2="{sx(lo)}" y2="{pad+panel_h}" stroke="#999" stroke-dasharray="4 3"/>')
        svg.append(f'<line x1="{sx(hi)}" y1="{pad}" x2="{sx(hi)}" y2="{pad+panel_h}" stroke="#999" stroke-dasharray="4 3"/>')
    svg.append("</svg>")
    out_path.write_text("\n".join(svg), encoding="utf-8")


def main():
    args = parse_args()
    base_dir = Path(__file__).resolve().parent
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    cont_cols = load_split(base_dir)
    ref_glob = resolve_reference_glob(base_dir, args.reference)
    ref_files = sorted(glob(ref_glob))
    if not ref_files:
        raise SystemExit(f"no reference files matched: {ref_glob}")
    gen_path = Path(args.generated)
    gen_vals, gen_rows = read_csv_values(gen_path, cont_cols, args.max_rows, args.stride, gz=False)
    # Reference values (aggregate across files)
    real_vals = {c: [] for c in cont_cols}
    total_rows = 0
    for f in ref_files:
        vals, rows = read_csv_values(Path(f), cont_cols, args.max_rows, args.stride, gz=True)
        total_rows += rows
        for c in cont_cols:
            real_vals[c].extend(vals[c])
    # KS per feature
    ks_rows = []
    for c in cont_cols:
        ks = ks_statistic(real_vals[c], gen_vals[c])
        # boundary pile-up (using real min/max)
        if real_vals[c]:
            lo = min(real_vals[c])
            hi = max(real_vals[c])
        else:
            lo = hi = 0.0
        tol = (hi - lo) * 1e-4 if hi > lo else 1e-6
        gen = gen_vals[c]
        if gen:
            frac_lo = sum(1 for v in gen if abs(v - lo) <= tol) / len(gen)
            frac_hi = sum(1 for v in gen if abs(v - hi) <= tol) / len(gen)
        else:
            frac_lo = frac_hi = 0.0
        ks_rows.append((c, ks, frac_lo, frac_hi, len(real_vals[c]), len(gen_vals[c]), lo, hi))
    ks_rows.sort(key=lambda x: x[1], reverse=True)
    out_csv = out_dir / "ks_per_feature.csv"
    with out_csv.open("w", newline="") as fh:
        w = csv.writer(fh)
        w.writerow(["feature", "ks", "gen_frac_at_min", "gen_frac_at_max", "real_n", "gen_n", "real_min", "real_max"])
        for row in ks_rows:
            w.writerow(row)
    # top-k CDF plots
    for c, ks, _, _, _, _, lo, hi in ks_rows[: args.top_k]:
        out_svg = out_dir / f"cdf_{c}.svg"
        render_cdf_svg(out_svg, c, real_vals[c], gen_vals[c], bounds=(lo, hi))
    # summary
    summary = {
        "generated_rows": gen_rows,
        "reference_rows_per_file": args.max_rows if args.max_rows > 0 else "full",
        "stride": args.stride,
        "top_k_features": [r[0] for r in ks_rows[: args.top_k]],
    }
    (out_dir / "ks_summary.json").write_text(json.dumps(summary, indent=2), encoding="utf-8")
    print(f"wrote {out_csv}")
    print(f"wrote CDF svgs for top {args.top_k} features")


if __name__ == "__main__":
    main()