Limitations of continuous features in capturing temporal correlation
@@ -4,6 +4,8 @@
 import csv
 import gzip
 import json
+import math
+import random
 from typing import Dict, Iterable, List, Optional, Tuple, Union


@@ -23,62 +25,173 @@ def iter_rows(path_or_paths: Union[str, List[str]]) -> Iterable[Dict[str, str]]:
         yield row


-def compute_cont_stats(
+def _stream_basic_stats(
     path: Union[str, List[str]],
     cont_cols: List[str],
     max_rows: Optional[int] = None,
-) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, bool], Dict[str, int]]:
-    """Streaming mean/std (Welford) + min/max + int/precision metadata."""
-    count = 0
+):
+    """Streaming stats with mean/M2/M3 + min/max + int/precision metadata."""
+    count = {c: 0 for c in cont_cols}
     mean = {c: 0.0 for c in cont_cols}
     m2 = {c: 0.0 for c in cont_cols}
+    m3 = {c: 0.0 for c in cont_cols}
     vmin = {c: float("inf") for c in cont_cols}
     vmax = {c: float("-inf") for c in cont_cols}
     int_like = {c: True for c in cont_cols}
     max_decimals = {c: 0 for c in cont_cols}
+    all_pos = {c: True for c in cont_cols}

     for i, row in enumerate(iter_rows(path)):
-        count += 1
         for c in cont_cols:
             raw = row[c]
             if raw is None or raw == "":
                 continue
             x = float(raw)
-            delta = x - mean[c]
-            mean[c] += delta / count
-            delta2 = x - mean[c]
-            m2[c] += delta * delta2
+            if x <= 0:
+                all_pos[c] = False
             if x < vmin[c]:
                 vmin[c] = x
             if x > vmax[c]:
                 vmax[c] = x
             if int_like[c] and abs(x - round(x)) > 1e-9:
                 int_like[c] = False
             # track decimal places from raw string if possible
-            if "e" in raw or "E" in raw:
-                # scientific notation, skip precision inference
-                continue
-            if "." in raw:
+            if "e" not in raw and "E" not in raw and "." in raw:
                 dec = raw.split(".", 1)[1].rstrip("0")
                 if len(dec) > max_decimals[c]:
                     max_decimals[c] = len(dec)
+
+            n = count[c] + 1
+            delta = x - mean[c]
+            delta_n = delta / n
+            term1 = delta * delta_n * (n - 1)
+            m3[c] += term1 * delta_n * (n - 2) - 3 * delta_n * m2[c]
+            m2[c] += term1
+            mean[c] += delta_n
+            count[c] = n

         if max_rows is not None and i + 1 >= max_rows:
             break

     # finalize std/skew
     std = {}
+    skew = {}
     for c in cont_cols:
-        if count > 1:
-            var = m2[c] / (count - 1)
+        n = count[c]
+        if n > 1:
+            var = m2[c] / (n - 1)
         else:
             var = 0.0
         std[c] = var ** 0.5 if var > 0 else 1.0
+        if n > 2 and m2[c] > 0:
+            skew[c] = (math.sqrt(n) * m3[c]) / (m2[c] ** 1.5)
+        else:
+            skew[c] = 0.0

     # replace infs if column had no valid values
     for c in cont_cols:
         if vmin[c] == float("inf"):
             vmin[c] = 0.0
         if vmax[c] == float("-inf"):
             vmax[c] = 0.0
-    return mean, std, vmin, vmax, int_like, max_decimals
+
+    return {
+        "count": count,
+        "mean": mean,
+        "std": std,
+        "m2": m2,
+        "m3": m3,
+        "min": vmin,
+        "max": vmax,
+        "int_like": int_like,
+        "max_decimals": max_decimals,
+        "skew": skew,
+        "all_pos": all_pos,
+    }
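The moment updates above are the standard one-pass (Welford-style) recurrences for the mean, M2, and M3 sums, with sample skewness `g1 = sqrt(n) * M3 / M2**1.5`. As a quick sanity check, independent of this module, the streaming skewness can be compared against a naive two-pass computation (a minimal sketch; the names here are illustrative):

```python
import math
import random

def naive_skew(xs):
    # two-pass sample skewness: g1 = sqrt(n) * M3 / M2**1.5, M2/M3 as sums
    n = len(xs)
    mu = sum(xs) / n
    m2 = sum((x - mu) ** 2 for x in xs)
    m3 = sum((x - mu) ** 3 for x in xs)
    return math.sqrt(n) * m3 / m2 ** 1.5

def stream_skew(xs):
    # one-pass update, identical to the per-column loop in _stream_basic_stats
    n, mean, m2, m3 = 0, 0.0, 0.0, 0.0
    for x in xs:
        n += 1
        delta = x - mean
        delta_n = delta / n
        term1 = delta * delta_n * (n - 1)
        m3 += term1 * delta_n * (n - 2) - 3 * delta_n * m2
        m2 += term1
        mean += delta_n
    return math.sqrt(n) * m3 / m2 ** 1.5

xs = [random.lognormvariate(0, 1) for _ in range(10_000)]
assert abs(naive_skew(xs) - stream_skew(xs)) < 1e-6
```

The diff continues with the new transform-selection helper: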
+
+
+def choose_cont_transforms(
+    path: Union[str, List[str]],
+    cont_cols: List[str],
+    max_rows: Optional[int] = None,
+    skew_threshold: float = 1.5,
+    range_ratio_threshold: float = 1e3,
+):
+    """Pick per-column transform (currently log1p or none) based on skew/range."""
+    stats = _stream_basic_stats(path, cont_cols, max_rows=max_rows)
+    transforms = {}
+    for c in cont_cols:
+        if not stats["all_pos"][c]:
+            transforms[c] = "none"
+            continue
+        skew = abs(stats["skew"][c])
+        vmin = stats["min"][c]
+        vmax = stats["max"][c]
+        ratio = (vmax / vmin) if vmin > 0 else 0.0
+        if skew >= skew_threshold or ratio >= range_ratio_threshold:
+            transforms[c] = "log1p"
+        else:
+            transforms[c] = "none"
+    return transforms, stats
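To make the heuristic concrete, here is a hypothetical call; the file and column names are invented, and the expectation is that a heavy-tailed, strictly positive counter comes back as "log1p" while a roughly symmetric column stays "none":

```python
transforms, stats = choose_cont_transforms(
    "flows.csv.gz",                        # invented path; the module imports gzip for compressed input
    cont_cols=["bytes_sent", "duration"],  # invented column names
    max_rows=100_000,                      # bound the stats pass on large files
)
print(transforms)  # e.g. {"bytes_sent": "log1p", "duration": "none"}
```

The rewritten `compute_cont_stats` then follows in the same hunk: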
+
+
+def compute_cont_stats(
+    path: Union[str, List[str]],
+    cont_cols: List[str],
+    max_rows: Optional[int] = None,
+    transforms: Optional[Dict[str, str]] = None,
+):
+    """Compute stats on (optionally transformed) values. Returns raw + transformed stats."""
+    # First pass (raw) for metadata and raw mean/std
+    raw = _stream_basic_stats(path, cont_cols, max_rows=max_rows)
+
+    # Optional transform selection
+    if transforms is None:
+        transforms = {c: "none" for c in cont_cols}
+
+    # Second pass for transformed mean/std
+    count = {c: 0 for c in cont_cols}
+    mean = {c: 0.0 for c in cont_cols}
+    m2 = {c: 0.0 for c in cont_cols}
+    for i, row in enumerate(iter_rows(path)):
+        for c in cont_cols:
+            raw_val = row[c]
+            if raw_val is None or raw_val == "":
+                continue
+            x = float(raw_val)
+            if transforms.get(c) == "log1p":
+                if x < 0:
+                    x = 0.0
+                x = math.log1p(x)
+            n = count[c] + 1
+            delta = x - mean[c]
+            mean[c] += delta / n
+            delta2 = x - mean[c]
+            m2[c] += delta * delta2
+            count[c] = n
+        if max_rows is not None and i + 1 >= max_rows:
+            break
+
+    std = {}
+    for c in cont_cols:
+        if count[c] > 1:
+            var = m2[c] / (count[c] - 1)
+        else:
+            var = 0.0
+        std[c] = var ** 0.5 if var > 0 else 1.0
+
+    return {
+        "mean": mean,
+        "std": std,
+        "raw_mean": raw["mean"],
+        "raw_std": raw["std"],
+        "min": raw["min"],
+        "max": raw["max"],
+        "int_like": raw["int_like"],
+        "max_decimals": raw["max_decimals"],
+        "transform": transforms,
+        "skew": raw["skew"],
+        "all_pos": raw["all_pos"],
+        "max_rows": max_rows,
+    }
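The intended flow is selection followed by a re-fit, and the transformed mean/std only make sense alongside the transforms that produced them. A usage sketch (file and column names invented):

```python
cont_cols = ["bytes_sent", "duration"]   # invented column names
transforms, _ = choose_cont_transforms("train.csv", cont_cols)
stats = compute_cont_stats("train.csv", cont_cols, transforms=transforms)
# stats["mean"] / stats["std"] live in transformed (e.g. log1p) space and are
# only valid when paired with the same `transforms` at batch time;
# stats["raw_mean"] / stats["raw_std"] describe the untransformed columns.
```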
+
+
 def build_vocab(
@@ -130,8 +243,19 @@ def build_disc_stats(
     return vocab, top_token


-def normalize_cont(x, cont_cols: List[str], mean: Dict[str, float], std: Dict[str, float]):
+def normalize_cont(
+    x,
+    cont_cols: List[str],
+    mean: Dict[str, float],
+    std: Dict[str, float],
+    transforms: Optional[Dict[str, str]] = None,
+):
     import torch
+
+    if transforms:
+        for i, c in enumerate(cont_cols):
+            if transforms.get(c) == "log1p":
+                x[:, :, i] = torch.log1p(torch.clamp(x[:, :, i], min=0))
     mean_t = torch.tensor([mean[c] for c in cont_cols], dtype=x.dtype, device=x.device)
     std_t = torch.tensor([std[c] for c in cont_cols], dtype=x.dtype, device=x.device)
     return (x - mean_t) / std_t
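`normalize_cont` applies log1p in place to the selected feature slices (clamping negatives to zero, mirroring `compute_cont_stats`) and then z-scores every column against the supplied stats. A minimal shape-level check, assuming torch is installed (all numbers invented):

```python
import torch

mean = {"bytes_sent": 2.0, "duration": 0.5}   # pretend log1p-space stats
std = {"bytes_sent": 1.5, "duration": 0.25}
x = torch.rand(4, 16, 2) * 100                # [batch, seq_len, num_features]
out = normalize_cont(x, ["bytes_sent", "duration"], mean, std,
                     transforms={"bytes_sent": "log1p"})
assert out.shape == x.shape                   # per-feature z-score, shape kept
```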
@@ -148,19 +272,34 @@ def windowed_batches(
     seq_len: int,
     max_batches: Optional[int] = None,
     return_file_id: bool = False,
+    transforms: Optional[Dict[str, str]] = None,
+    shuffle_buffer: int = 0,
 ):
     import torch
     batch_cont = []
     batch_disc = []
     batch_file = []
+    buffer = []
     seq_cont = []
     seq_disc = []

-    def flush_seq():
-        nonlocal seq_cont, seq_disc, batch_cont, batch_disc
+    def flush_seq(file_id: int):
+        nonlocal seq_cont, seq_disc, batch_cont, batch_disc, batch_file
         if len(seq_cont) == seq_len:
-            batch_cont.append(seq_cont)
-            batch_disc.append(seq_disc)
+            if shuffle_buffer and shuffle_buffer > 0:
+                buffer.append((list(seq_cont), list(seq_disc), file_id))
+                if len(buffer) >= shuffle_buffer:
+                    idx = random.randrange(len(buffer))
+                    seq_c, seq_d, seq_f = buffer.pop(idx)
+                    batch_cont.append(seq_c)
+                    batch_disc.append(seq_d)
+                    if return_file_id:
+                        batch_file.append(seq_f)
+            else:
+                batch_cont.append(seq_cont)
+                batch_disc.append(seq_disc)
+                if return_file_id:
+                    batch_file.append(file_id)
         seq_cont = []
         seq_disc = []

@@ -173,13 +312,11 @@ def windowed_batches(
             seq_cont.append(cont_row)
             seq_disc.append(disc_row)
             if len(seq_cont) == seq_len:
-                flush_seq()
-                if return_file_id:
-                    batch_file.append(file_id)
+                flush_seq(file_id)
                 if len(batch_cont) == batch_size:
                     x_cont = torch.tensor(batch_cont, dtype=torch.float32)
                     x_disc = torch.tensor(batch_disc, dtype=torch.long)
-                    x_cont = normalize_cont(x_cont, cont_cols, mean, std)
+                    x_cont = normalize_cont(x_cont, cont_cols, mean, std, transforms=transforms)
                     if return_file_id:
                         x_file = torch.tensor(batch_file, dtype=torch.long)
                         yield x_cont, x_disc, x_file
@@ -195,4 +332,29 @@ def windowed_batches(
                    seq_cont = []
                    seq_disc = []

+    # Flush any remaining buffered sequences
+    if shuffle_buffer and buffer:
+        random.shuffle(buffer)
+        for seq_c, seq_d, seq_f in buffer:
+            batch_cont.append(seq_c)
+            batch_disc.append(seq_d)
+            if return_file_id:
+                batch_file.append(seq_f)
+            if len(batch_cont) == batch_size:
+                import torch
+                x_cont = torch.tensor(batch_cont, dtype=torch.float32)
+                x_disc = torch.tensor(batch_disc, dtype=torch.long)
+                x_cont = normalize_cont(x_cont, cont_cols, mean, std, transforms=transforms)
+                if return_file_id:
+                    x_file = torch.tensor(batch_file, dtype=torch.long)
+                    yield x_cont, x_disc, x_file
+                else:
+                    yield x_cont, x_disc
+                batch_cont = []
+                batch_disc = []
+                batch_file = []
+                batches_yielded += 1
+                if max_batches is not None and batches_yielded >= max_batches:
+                    return
+
     # Drop last partial batch for simplicity
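The buffered shuffle used by `flush_seq` and the tail flush above is the usual bounded-memory approximation of a full shuffle (the same idea as a tf.data-style shuffle buffer): hold at most `shuffle_buffer` windows, emit a uniformly chosen one once the buffer is full, and drain the remainder at end of stream. Stripped to its essence (a sketch, not the commit's code):

```python
import random

def shuffled_stream(items, buffer_size):
    buffer = []
    for item in items:
        buffer.append(item)
        if len(buffer) >= buffer_size:
            # emit a uniformly chosen element, keeping memory bounded
            yield buffer.pop(random.randrange(len(buffer)))
    random.shuffle(buffer)  # drain the tail, mirroring the post-loop flush
    yield from buffer

print(list(shuffled_stream(range(10), buffer_size=4)))
```

Larger `buffer_size` values approach a uniform shuffle at the cost of memory; `shuffle_buffer=0` preserves the old strictly sequential behavior.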