Add full quantile stats and post-hoc calibration
This commit is contained in:
@@ -153,12 +153,15 @@ def compute_cont_stats(
|
||||
mean = {c: 0.0 for c in cont_cols}
|
||||
m2 = {c: 0.0 for c in cont_cols}
|
||||
quantile_values = {c: [] for c in cont_cols} if quantile_bins and quantile_bins > 1 else None
|
||||
raw_quantile_values = {c: [] for c in cont_cols} if quantile_bins and quantile_bins > 1 else None
|
||||
for i, row in enumerate(iter_rows(path)):
|
||||
for c in cont_cols:
|
||||
raw_val = row[c]
|
||||
if raw_val is None or raw_val == "":
|
||||
continue
|
||||
x = float(raw_val)
|
||||
if raw_quantile_values is not None:
|
||||
raw_quantile_values[c].append(x)
|
||||
if transforms.get(c) == "log1p":
|
||||
if x < 0:
|
||||
x = 0.0
|
||||
@@ -184,22 +187,36 @@ def compute_cont_stats(
|
||||
|
||||
quantile_probs = None
|
||||
quantile_table = None
|
||||
raw_quantile_table = None
|
||||
if quantile_values is not None:
|
||||
quantile_probs = [i / (quantile_bins - 1) for i in range(quantile_bins)]
|
||||
quantile_table = {}
|
||||
raw_quantile_table = {}
|
||||
for c in cont_cols:
|
||||
vals = quantile_values[c]
|
||||
if not vals:
|
||||
quantile_table[c] = [0.0 for _ in quantile_probs]
|
||||
else:
|
||||
vals.sort()
|
||||
n = len(vals)
|
||||
qvals = []
|
||||
for p in quantile_probs:
|
||||
idx = int(round(p * (n - 1)))
|
||||
idx = max(0, min(n - 1, idx))
|
||||
qvals.append(float(vals[idx]))
|
||||
quantile_table[c] = qvals
|
||||
raw_vals = raw_quantile_values[c] if raw_quantile_values is not None else []
|
||||
if not raw_vals:
|
||||
raw_quantile_table[c] = [0.0 for _ in quantile_probs]
|
||||
continue
|
||||
vals.sort()
|
||||
n = len(vals)
|
||||
qvals = []
|
||||
raw_vals.sort()
|
||||
n = len(raw_vals)
|
||||
rqvals = []
|
||||
for p in quantile_probs:
|
||||
idx = int(round(p * (n - 1)))
|
||||
idx = max(0, min(n - 1, idx))
|
||||
qvals.append(float(vals[idx]))
|
||||
quantile_table[c] = qvals
|
||||
rqvals.append(float(raw_vals[idx]))
|
||||
raw_quantile_table[c] = rqvals
|
||||
|
||||
return {
|
||||
"mean": mean,
|
||||
@@ -216,6 +233,7 @@ def compute_cont_stats(
|
||||
"max_rows": max_rows,
|
||||
"quantile_probs": quantile_probs,
|
||||
"quantile_values": quantile_table,
|
||||
"quantile_raw_values": raw_quantile_table,
|
||||
}
|
||||
|
||||
|
||||
@@ -344,6 +362,35 @@ def inverse_quantile_transform(x, cont_cols, quantile_probs, quantile_values):
|
||||
return x
|
||||
|
||||
|
||||
def quantile_calibrate_to_real(x, cont_cols, quantile_probs, real_quantile_values):
|
||||
import torch
|
||||
probs_t = torch.tensor(quantile_probs, dtype=x.dtype, device=x.device)
|
||||
flat = x.reshape(-1, x.size(-1))
|
||||
for i, c in enumerate(cont_cols):
|
||||
v = flat[:, i]
|
||||
gen_q = torch.quantile(v, probs_t)
|
||||
idx = torch.bucketize(v, gen_q)
|
||||
idx = torch.clamp(idx, 1, gen_q.numel() - 1)
|
||||
x0 = gen_q[idx - 1]
|
||||
x1 = gen_q[idx]
|
||||
p0 = probs_t[idx - 1]
|
||||
p1 = probs_t[idx]
|
||||
denom = torch.where((x1 - x0) == 0, torch.ones_like(x1 - x0), (x1 - x0))
|
||||
p = p0 + (v - x0) * (p1 - p0) / denom
|
||||
|
||||
real_q = torch.tensor(real_quantile_values[c], dtype=x.dtype, device=x.device)
|
||||
idx2 = torch.bucketize(p, probs_t)
|
||||
idx2 = torch.clamp(idx2, 1, probs_t.numel() - 1)
|
||||
rp0 = probs_t[idx2 - 1]
|
||||
rp1 = probs_t[idx2]
|
||||
r0 = real_q[idx2 - 1]
|
||||
r1 = real_q[idx2]
|
||||
denom2 = torch.where((rp1 - rp0) == 0, torch.ones_like(rp1 - rp0), (rp1 - rp0))
|
||||
v2 = r0 + (p - rp0) * (r1 - r0) / denom2
|
||||
flat[:, i] = v2
|
||||
return flat.reshape(x.shape)
|
||||
|
||||
|
||||
def windowed_batches(
|
||||
path: Union[str, List[str]],
|
||||
cont_cols: List[str],
|
||||
|
||||
Reference in New Issue
Block a user