优化6个类,现在ks降低到0.28,史称3.0版本

This commit is contained in:
2026-01-28 20:10:42 +08:00
parent 59697c0640
commit 39eede92f6
28 changed files with 3317 additions and 225 deletions

View File

@@ -75,15 +75,16 @@ def ks_statistic(x: List[float], y: List[float]) -> float:
n = len(x_sorted)
m = len(y_sorted)
i = j = 0
cdf_x = cdf_y = 0.0
d = 0.0
while i < n and j < m:
if x_sorted[i] <= y_sorted[j]:
# Iterate over merged unique values to handle ties correctly
merged = sorted(set(x_sorted) | set(y_sorted))
for v in merged:
while i < n and x_sorted[i] <= v:
i += 1
cdf_x = i / n
else:
while j < m and y_sorted[j] <= v:
j += 1
cdf_y = j / m
cdf_x = i / n
cdf_y = j / m
d = max(d, abs(cdf_x - cdf_y))
return d
@@ -103,31 +104,43 @@ def lag1_corr(values: List[float]) -> float:
return num / math.sqrt(den_x * den_y)
def resolve_reference_paths(path: str) -> List[str]:
    """Resolve *path* into a list of concrete reference-file paths.

    If *path* contains glob metacharacters (``*``, ``?``, ``[``), the
    pattern is expanded against its resolved parent directory and all
    matches are returned in sorted order (possibly empty).  Otherwise the
    path is returned untouched as a single-element list.

    Returns an empty list when *path* is empty or falsy.

    NOTE(review): glob characters are only honoured in the final path
    component; wildcards in directory components are not expanded.
    """
    if not path:
        return []
    # Glob pattern: expand the last component against its parent directory.
    if any(ch in path for ch in ["*", "?", "["]):
        base = Path(path).parent.resolve()
        pattern = Path(path).name
        matches = sorted(base.glob(pattern))
        return [str(p) for p in matches]
    # Plain path: pass through without resolving, matching caller expectations.
    return [str(path)]
def main():
args = parse_args()
base_dir = Path(__file__).resolve().parent
args.generated = str((base_dir / args.generated).resolve()) if not Path(args.generated).is_absolute() else args.generated
args.split = str((base_dir / args.split).resolve()) if not Path(args.split).is_absolute() else args.split
args.stats = str((base_dir / args.stats).resolve()) if not Path(args.stats).is_absolute() else args.stats
args.vocab = str((base_dir / args.vocab).resolve()) if not Path(args.vocab).is_absolute() else args.vocab
args.out = str((base_dir / args.out).resolve()) if not Path(args.out).is_absolute() else args.out
def resolve_file(p: str) -> str:
path = Path(p)
if path.is_absolute():
return str(path)
if path.exists():
return str(path.resolve())
candidate = base_dir / path
if candidate.exists():
return str(candidate.resolve())
return str((base_dir / path).resolve())
args.generated = resolve_file(args.generated)
args.split = resolve_file(args.split)
args.stats = resolve_file(args.stats)
args.vocab = resolve_file(args.vocab)
args.out = resolve_file(args.out)
if args.reference and not Path(args.reference).is_absolute():
if any(ch in args.reference for ch in ["*", "?", "["]):
args.reference = str(base_dir / args.reference)
else:
args.reference = str((base_dir / args.reference).resolve())
ref_path = resolve_reference_path(args.reference)
ref_paths = resolve_reference_paths(args.reference)
split = load_json(args.split)
time_col = split.get("time_column", "time")
cont_cols = [c for c in split["continuous"] if c != time_col]
@@ -156,7 +169,7 @@ def main():
except Exception:
v = 0.0
update_stats(cont_stats, c, v)
if ref_path:
if ref_paths:
pass
for c in disc_cols:
if row[c] not in vocab_sets[c]:
@@ -182,7 +195,7 @@ def main():
}
# Optional richer metrics using reference data
if ref_path:
if ref_paths:
ref_cont = {c: [] for c in cont_cols}
ref_disc = {c: {} for c in disc_cols}
gen_cont = {c: [] for c in cont_cols}
@@ -202,21 +215,26 @@ def main():
tok = row[c]
gen_disc[c][tok] = gen_disc[c].get(tok, 0) + 1
with open_csv(ref_path) as f:
reader = csv.DictReader(f)
for i, row in enumerate(reader):
if time_col in row:
row.pop(time_col, None)
for c in cont_cols:
try:
ref_cont[c].append(float(row[c]))
except Exception:
ref_cont[c].append(0.0)
for c in disc_cols:
tok = row[c]
ref_disc[c][tok] = ref_disc[c].get(tok, 0) + 1
if args.max_rows and i + 1 >= args.max_rows:
break
loaded = 0
for ref_path in ref_paths:
with open_csv(ref_path) as f:
reader = csv.DictReader(f)
for row in reader:
if time_col in row:
row.pop(time_col, None)
for c in cont_cols:
try:
ref_cont[c].append(float(row[c]))
except Exception:
ref_cont[c].append(0.0)
for c in disc_cols:
tok = row[c]
ref_disc[c][tok] = ref_disc[c].get(tok, 0) + 1
loaded += 1
if args.max_rows and loaded >= args.max_rows:
break
if args.max_rows and loaded >= args.max_rows:
break
# Continuous metrics: KS + quantiles + lag1 correlation
cont_ks = {}