优化6个类,现在ks降低到0.28,史称3.0版本
This commit is contained in:
@@ -75,15 +75,16 @@ def ks_statistic(x: List[float], y: List[float]) -> float:
|
||||
n = len(x_sorted)
|
||||
m = len(y_sorted)
|
||||
i = j = 0
|
||||
cdf_x = cdf_y = 0.0
|
||||
d = 0.0
|
||||
while i < n and j < m:
|
||||
if x_sorted[i] <= y_sorted[j]:
|
||||
# Iterate over merged unique values to handle ties correctly
|
||||
merged = sorted(set(x_sorted) | set(y_sorted))
|
||||
for v in merged:
|
||||
while i < n and x_sorted[i] <= v:
|
||||
i += 1
|
||||
cdf_x = i / n
|
||||
else:
|
||||
while j < m and y_sorted[j] <= v:
|
||||
j += 1
|
||||
cdf_y = j / m
|
||||
cdf_x = i / n
|
||||
cdf_y = j / m
|
||||
d = max(d, abs(cdf_x - cdf_y))
|
||||
return d
|
||||
|
||||
@@ -103,31 +104,43 @@ def lag1_corr(values: List[float]) -> float:
|
||||
return num / math.sqrt(den_x * den_y)
|
||||
|
||||
|
||||
def resolve_reference_path(path: str) -> Optional[str]:
|
||||
def resolve_reference_paths(path: str) -> List[str]:
|
||||
if not path:
|
||||
return None
|
||||
return []
|
||||
if any(ch in path for ch in ["*", "?", "["]):
|
||||
base = Path(path).parent
|
||||
base = Path(path).parent.resolve()
|
||||
pat = Path(path).name
|
||||
matches = sorted(base.glob(pat))
|
||||
return str(matches[0]) if matches else None
|
||||
return str(path)
|
||||
return [str(p) for p in matches]
|
||||
return [str(path)]
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
base_dir = Path(__file__).resolve().parent
|
||||
args.generated = str((base_dir / args.generated).resolve()) if not Path(args.generated).is_absolute() else args.generated
|
||||
args.split = str((base_dir / args.split).resolve()) if not Path(args.split).is_absolute() else args.split
|
||||
args.stats = str((base_dir / args.stats).resolve()) if not Path(args.stats).is_absolute() else args.stats
|
||||
args.vocab = str((base_dir / args.vocab).resolve()) if not Path(args.vocab).is_absolute() else args.vocab
|
||||
args.out = str((base_dir / args.out).resolve()) if not Path(args.out).is_absolute() else args.out
|
||||
|
||||
def resolve_file(p: str) -> str:
|
||||
path = Path(p)
|
||||
if path.is_absolute():
|
||||
return str(path)
|
||||
if path.exists():
|
||||
return str(path.resolve())
|
||||
candidate = base_dir / path
|
||||
if candidate.exists():
|
||||
return str(candidate.resolve())
|
||||
return str((base_dir / path).resolve())
|
||||
|
||||
args.generated = resolve_file(args.generated)
|
||||
args.split = resolve_file(args.split)
|
||||
args.stats = resolve_file(args.stats)
|
||||
args.vocab = resolve_file(args.vocab)
|
||||
args.out = resolve_file(args.out)
|
||||
if args.reference and not Path(args.reference).is_absolute():
|
||||
if any(ch in args.reference for ch in ["*", "?", "["]):
|
||||
args.reference = str(base_dir / args.reference)
|
||||
else:
|
||||
args.reference = str((base_dir / args.reference).resolve())
|
||||
ref_path = resolve_reference_path(args.reference)
|
||||
ref_paths = resolve_reference_paths(args.reference)
|
||||
split = load_json(args.split)
|
||||
time_col = split.get("time_column", "time")
|
||||
cont_cols = [c for c in split["continuous"] if c != time_col]
|
||||
@@ -156,7 +169,7 @@ def main():
|
||||
except Exception:
|
||||
v = 0.0
|
||||
update_stats(cont_stats, c, v)
|
||||
if ref_path:
|
||||
if ref_paths:
|
||||
pass
|
||||
for c in disc_cols:
|
||||
if row[c] not in vocab_sets[c]:
|
||||
@@ -182,7 +195,7 @@ def main():
|
||||
}
|
||||
|
||||
# Optional richer metrics using reference data
|
||||
if ref_path:
|
||||
if ref_paths:
|
||||
ref_cont = {c: [] for c in cont_cols}
|
||||
ref_disc = {c: {} for c in disc_cols}
|
||||
gen_cont = {c: [] for c in cont_cols}
|
||||
@@ -202,21 +215,26 @@ def main():
|
||||
tok = row[c]
|
||||
gen_disc[c][tok] = gen_disc[c].get(tok, 0) + 1
|
||||
|
||||
with open_csv(ref_path) as f:
|
||||
reader = csv.DictReader(f)
|
||||
for i, row in enumerate(reader):
|
||||
if time_col in row:
|
||||
row.pop(time_col, None)
|
||||
for c in cont_cols:
|
||||
try:
|
||||
ref_cont[c].append(float(row[c]))
|
||||
except Exception:
|
||||
ref_cont[c].append(0.0)
|
||||
for c in disc_cols:
|
||||
tok = row[c]
|
||||
ref_disc[c][tok] = ref_disc[c].get(tok, 0) + 1
|
||||
if args.max_rows and i + 1 >= args.max_rows:
|
||||
break
|
||||
loaded = 0
|
||||
for ref_path in ref_paths:
|
||||
with open_csv(ref_path) as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
if time_col in row:
|
||||
row.pop(time_col, None)
|
||||
for c in cont_cols:
|
||||
try:
|
||||
ref_cont[c].append(float(row[c]))
|
||||
except Exception:
|
||||
ref_cont[c].append(0.0)
|
||||
for c in disc_cols:
|
||||
tok = row[c]
|
||||
ref_disc[c][tok] = ref_disc[c].get(tok, 0) + 1
|
||||
loaded += 1
|
||||
if args.max_rows and loaded >= args.max_rows:
|
||||
break
|
||||
if args.max_rows and loaded >= args.max_rows:
|
||||
break
|
||||
|
||||
# Continuous metrics: KS + quantiles + lag1 correlation
|
||||
cont_ks = {}
|
||||
|
||||
Reference in New Issue
Block a user