Shortcomings of continuous features in capturing temporal correlation

2026-01-23 15:06:52 +08:00
parent 0d17be9a1c
commit ff12324560
12 changed files with 1212 additions and 68 deletions

@@ -4,6 +4,8 @@
import csv
import gzip
import json
import math
import random
from typing import Dict, Iterable, List, Optional, Tuple, Union
@@ -23,62 +25,173 @@ def iter_rows(path_or_paths: Union[str, List[str]]) -> Iterable[Dict[str, str]]:
yield row
def _stream_basic_stats(
    path: Union[str, List[str]],
    cont_cols: List[str],
    max_rows: Optional[int] = None,
):
    """Streaming stats with mean/M2/M3 + min/max + int/precision metadata."""
    count = {c: 0 for c in cont_cols}
    mean = {c: 0.0 for c in cont_cols}
    m2 = {c: 0.0 for c in cont_cols}   # sum of squared deviations (for variance)
    m3 = {c: 0.0 for c in cont_cols}   # sum of cubed deviations (for skewness)
    vmin = {c: float("inf") for c in cont_cols}
    vmax = {c: float("-inf") for c in cont_cols}
    int_like = {c: True for c in cont_cols}
    max_decimals = {c: 0 for c in cont_cols}
    all_pos = {c: True for c in cont_cols}
    for i, row in enumerate(iter_rows(path)):
        for c in cont_cols:
            raw = row[c]
            if raw is None or raw == "":
                continue
            x = float(raw)
            if x <= 0:
                all_pos[c] = False
            if x < vmin[c]:
                vmin[c] = x
            if x > vmax[c]:
                vmax[c] = x
            if int_like[c] and abs(x - round(x)) > 1e-9:
                int_like[c] = False
            # track decimal places from the raw string if possible
            # (skip scientific notation, where digits after "." are not decimals)
            if "e" not in raw and "E" not in raw and "." in raw:
                dec = raw.split(".", 1)[1].rstrip("0")
                if len(dec) > max_decimals[c]:
                    max_decimals[c] = len(dec)
            # online update of mean/M2/M3 (Welford, extended to the third moment)
            n = count[c] + 1
            delta = x - mean[c]
            delta_n = delta / n
            term1 = delta * delta_n * (n - 1)
            m3[c] += term1 * delta_n * (n - 2) - 3 * delta_n * m2[c]
            m2[c] += term1
            mean[c] += delta_n
            count[c] = n
        if max_rows is not None and i + 1 >= max_rows:
            break
    # finalize std/skew
    std = {}
    skew = {}
    for c in cont_cols:
        n = count[c]
        if n > 1:
            var = m2[c] / (n - 1)
        else:
            var = 0.0
        std[c] = var ** 0.5 if var > 0 else 1.0
        # sample skewness g1 = sqrt(n) * M3 / M2**1.5
        if n > 2 and m2[c] > 0:
            skew[c] = math.sqrt(n) * m3[c] / (m2[c] ** 1.5)
        else:
            skew[c] = 0.0
    # replace infs if a column had no valid values
    for c in cont_cols:
        if vmin[c] == float("inf"):
            vmin[c] = 0.0
        if vmax[c] == float("-inf"):
            vmax[c] = 0.0
    return {
        "count": count,
        "mean": mean,
        "std": std,
        "m2": m2,
        "m3": m3,
        "min": vmin,
        "max": vmax,
        "int_like": int_like,
        "max_decimals": max_decimals,
        "skew": skew,
        "all_pos": all_pos,
    }
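The streaming update above is the third-moment extension of Welford's algorithm (Pébay's one-pass formulas). A quick way to sanity-check it is a direct computation on a small in-memory sample; this is a standalone sketch with made-up values, not part of the module:

import math

values = [1.0, 2.0, 2.0, 3.0, 100.0]      # heavy right tail
n = len(values)
mu = sum(values) / n
m2 = sum((v - mu) ** 2 for v in values)   # matches the streaming M2
m3 = sum((v - mu) ** 3 for v in values)   # matches the streaming M3
g1 = math.sqrt(n) * m3 / (m2 ** 1.5)      # sample skewness, same formula as above
print(round(g1, 3))                       # ~1.5, i.e. strongly right-skewed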

def choose_cont_transforms(
    path: Union[str, List[str]],
    cont_cols: List[str],
    max_rows: Optional[int] = None,
    skew_threshold: float = 1.5,
    range_ratio_threshold: float = 1e3,
):
    """Pick a per-column transform (currently log1p or none) based on skew/range."""
    stats = _stream_basic_stats(path, cont_cols, max_rows=max_rows)
    transforms = {}
    for c in cont_cols:
        # log1p is only applied to strictly non-negative columns
        if not stats["all_pos"][c]:
            transforms[c] = "none"
            continue
        skew = abs(stats["skew"][c])
        vmin = stats["min"][c]
        vmax = stats["max"][c]
        ratio = (vmax / vmin) if vmin > 0 else 0.0
        if skew >= skew_threshold or ratio >= range_ratio_threshold:
            transforms[c] = "log1p"
        else:
            transforms[c] = "none"
    return transforms, stats
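Under the defaults, a column is routed through log1p when |skew| >= 1.5 or when max/min spans three or more orders of magnitude. A hypothetical invocation (the file path and column names are placeholders):

transforms, stats = choose_cont_transforms(
    "train.csv",
    ["price", "qty"],
    max_rows=100_000,
)
# e.g. transforms == {"price": "log1p", "qty": "none"} for a heavy-tailed price column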

def compute_cont_stats(
    path: Union[str, List[str]],
    cont_cols: List[str],
    max_rows: Optional[int] = None,
    transforms: Optional[Dict[str, str]] = None,
):
    """Compute stats on (optionally transformed) values. Returns raw + transformed stats."""
    # First pass (raw) for metadata and raw mean/std
    raw = _stream_basic_stats(path, cont_cols, max_rows=max_rows)
    # Default to the identity transform for every column
    if transforms is None:
        transforms = {c: "none" for c in cont_cols}
    # Second pass for mean/std on the transformed scale
    count = {c: 0 for c in cont_cols}
    mean = {c: 0.0 for c in cont_cols}
    m2 = {c: 0.0 for c in cont_cols}
    for i, row in enumerate(iter_rows(path)):
        for c in cont_cols:
            raw_val = row[c]
            if raw_val is None or raw_val == "":
                continue
            x = float(raw_val)
            if transforms.get(c) == "log1p":
                # clamp negatives to 0 so log1p stays defined; matches normalize_cont
                if x < 0:
                    x = 0.0
                x = math.log1p(x)
            n = count[c] + 1
            delta = x - mean[c]
            mean[c] += delta / n
            delta2 = x - mean[c]
            m2[c] += delta * delta2
            count[c] = n
        if max_rows is not None and i + 1 >= max_rows:
            break
    std = {}
    for c in cont_cols:
        if count[c] > 1:
            var = m2[c] / (count[c] - 1)
        else:
            var = 0.0
        std[c] = var ** 0.5 if var > 0 else 1.0
    return {
        "mean": mean,
        "std": std,
        "raw_mean": raw["mean"],
        "raw_std": raw["std"],
        "min": raw["min"],
        "max": raw["max"],
        "int_like": raw["int_like"],
        "max_decimals": raw["max_decimals"],
        "transform": transforms,
        "skew": raw["skew"],
        "all_pos": raw["all_pos"],
        "max_rows": max_rows,
    }
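The intended flow is two passes over the data: pick transforms from raw-scale skew/range, then recompute mean/std on the transformed scale so normalization happens in the space the model actually sees. A hypothetical sketch (paths and cont_cols are placeholders):

transforms, _ = choose_cont_transforms(paths, cont_cols, max_rows=500_000)
stats = compute_cont_stats(paths, cont_cols, max_rows=500_000, transforms=transforms)
# stats["mean"] / stats["std"] are on the transformed (e.g. log1p) scale;
# stats["raw_mean"] / stats["raw_std"] stay on the original scale for reporting.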
def build_vocab(
@@ -130,8 +243,19 @@ def build_disc_stats(
return vocab, top_token

def normalize_cont(
    x,
    cont_cols: List[str],
    mean: Dict[str, float],
    std: Dict[str, float],
    transforms: Optional[Dict[str, str]] = None,
):
    import torch

    if transforms:
        for i, c in enumerate(cont_cols):
            if transforms.get(c) == "log1p":
                # apply the same clamp + log1p used when the stats were computed
                x[:, :, i] = torch.log1p(torch.clamp(x[:, :, i], min=0))
    mean_t = torch.tensor([mean[c] for c in cont_cols], dtype=x.dtype, device=x.device)
    std_t = torch.tensor([std[c] for c in cont_cols], dtype=x.dtype, device=x.device)
    return (x - mean_t) / std_t
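A minimal smoke test for the transform-aware normalization, with synthetic values and made-up stats (the mean/std for "price" are assumed to already be on the log1p scale, as compute_cont_stats produces):

import torch

x = torch.rand(2, 8, 2) * 1000        # (batch, seq_len, n_cont) raw values
cont_cols = ["price", "qty"]          # hypothetical column names
mean = {"price": 3.1, "qty": 0.5}
std = {"price": 1.2, "qty": 1.0}
out = normalize_cont(x, cont_cols, mean, std, transforms={"price": "log1p"})
print(out.shape)                      # torch.Size([2, 8, 2])

Note that the log1p columns of x are modified in place; clone x first if the raw values are still needed.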
@@ -148,19 +272,34 @@ def windowed_batches(
    seq_len: int,
    max_batches: Optional[int] = None,
    return_file_id: bool = False,
    transforms: Optional[Dict[str, str]] = None,
    shuffle_buffer: int = 0,
):
    import torch

    batch_cont = []
    batch_disc = []
    batch_file = []
    buffer = []
    seq_cont = []
    seq_disc = []

    def flush_seq(file_id: int):
        nonlocal seq_cont, seq_disc, batch_cont, batch_disc, batch_file
        if len(seq_cont) == seq_len:
            if shuffle_buffer and shuffle_buffer > 0:
                # reservoir-style local shuffle: hold sequences in a buffer and
                # emit a random one once the buffer is full
                buffer.append((list(seq_cont), list(seq_disc), file_id))
                if len(buffer) >= shuffle_buffer:
                    idx = random.randrange(len(buffer))
                    seq_c, seq_d, seq_f = buffer.pop(idx)
                    batch_cont.append(seq_c)
                    batch_disc.append(seq_d)
                    if return_file_id:
                        batch_file.append(seq_f)
            else:
                batch_cont.append(seq_cont)
                batch_disc.append(seq_disc)
                if return_file_id:
                    batch_file.append(file_id)
        seq_cont = []
        seq_disc = []
@@ -173,13 +312,11 @@ def windowed_batches(
        seq_cont.append(cont_row)
        seq_disc.append(disc_row)
        if len(seq_cont) == seq_len:
            flush_seq(file_id)
        if len(batch_cont) == batch_size:
            x_cont = torch.tensor(batch_cont, dtype=torch.float32)
            x_disc = torch.tensor(batch_disc, dtype=torch.long)
            x_cont = normalize_cont(x_cont, cont_cols, mean, std, transforms=transforms)
            if return_file_id:
                x_file = torch.tensor(batch_file, dtype=torch.long)
                yield x_cont, x_disc, x_file
@@ -195,4 +332,29 @@ def windowed_batches(
        seq_cont = []
        seq_disc = []

    # Flush any remaining buffered sequences
    if shuffle_buffer and buffer:
        random.shuffle(buffer)
        for seq_c, seq_d, seq_f in buffer:
            batch_cont.append(seq_c)
            batch_disc.append(seq_d)
            if return_file_id:
                batch_file.append(seq_f)
            if len(batch_cont) == batch_size:
                x_cont = torch.tensor(batch_cont, dtype=torch.float32)
                x_disc = torch.tensor(batch_disc, dtype=torch.long)
                x_cont = normalize_cont(x_cont, cont_cols, mean, std, transforms=transforms)
                if return_file_id:
                    x_file = torch.tensor(batch_file, dtype=torch.long)
                    yield x_cont, x_disc, x_file
                else:
                    yield x_cont, x_disc
                batch_cont = []
                batch_disc = []
                batch_file = []
                batches_yielded += 1
                if max_batches is not None and batches_yielded >= max_batches:
                    return
    # Drop last partial batch for simplicity
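shuffle_buffer gives an approximate shuffle without materializing the dataset, in the same spirit as TensorFlow's Dataset.shuffle. The pattern in isolation, as a self-contained sketch over a plain iterator (the names here are illustrative, not part of the module):

import random
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")

def buffered_shuffle(items: Iterable[T], buffer_size: int, seed: int = 0) -> Iterator[T]:
    """Yield items in a locally shuffled order using a fixed-size buffer."""
    rng = random.Random(seed)
    buf = []
    for item in items:
        buf.append(item)
        if len(buf) >= buffer_size:
            # swap-pop a random element: O(1) instead of list.pop(idx)'s O(n) shift
            idx = rng.randrange(len(buf))
            buf[idx], buf[-1] = buf[-1], buf[idx]
            yield buf.pop()
    rng.shuffle(buf)  # drain the tail in random order, like the final flush above
    yield from buf

print(list(buffered_shuffle(range(10), buffer_size=4)))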