update: accept one path or a list of paths in iter_rows/compute_cont_stats/build_vocab/windowed_batches, select gzip by ".gz" suffix, add optional per-sequence file ids via return_file_id
@@ -4,7 +4,7 @@
 import csv
 import gzip
 import json
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple, Union
 
 
 
@@ -13,15 +13,18 @@ def load_split(path: str) -> Dict[str, List[str]]:
         return json.load(f)
 
 
-def iter_rows(path: str) -> Iterable[Dict[str, str]]:
-    with gzip.open(path, "rt", newline="") as f:
-        reader = csv.DictReader(f)
-        for row in reader:
-            yield row
+def iter_rows(path_or_paths: Union[str, List[str]]) -> Iterable[Dict[str, str]]:
+    paths = [path_or_paths] if isinstance(path_or_paths, str) else list(path_or_paths)
+    for path in paths:
+        opener = gzip.open if str(path).endswith(".gz") else open
+        with opener(path, "rt", newline="") as f:
+            reader = csv.DictReader(f)
+            for row in reader:
+                yield row
 
 
 def compute_cont_stats(
-    path: str,
+    path: Union[str, List[str]],
     cont_cols: List[str],
     max_rows: Optional[int] = None,
 ) -> Tuple[Dict[str, float], Dict[str, float]]:
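The new iter_rows accepts either a single path or a list and picks gzip.open only for names ending in ".gz", so compressed and plain CSV files can be mixed in one stream. A minimal usage sketch against the function above (file names are hypothetical):

    for row in iter_rows(["trades_2023.csv.gz", "trades_2024.csv"]):
        # row is a plain Dict[str, str] from csv.DictReader; files are read back to back
        print(row)
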
@@ -52,7 +55,7 @@ def compute_cont_stats(
 
 
 def build_vocab(
-    path: str,
+    path: Union[str, List[str]],
     disc_cols: List[str],
     max_rows: Optional[int] = None,
 ) -> Dict[str, Dict[str, int]]:
@@ -80,7 +83,7 @@ def normalize_cont(x, cont_cols: List[str], mean: Dict[str, float], std: Dict[str, float]):
 
 
 def windowed_batches(
-    path: str,
+    path: Union[str, List[str]],
    cont_cols: List[str],
     disc_cols: List[str],
     vocab: Dict[str, Dict[str, int]],
@@ -89,10 +92,12 @@ def windowed_batches(
     batch_size: int,
     seq_len: int,
     max_batches: Optional[int] = None,
+    return_file_id: bool = False,
 ):
     import torch
     batch_cont = []
     batch_disc = []
+    batch_file = []
     seq_cont = []
     seq_disc = []
 
@@ -105,22 +110,34 @@ def windowed_batches(
         seq_disc = []
 
     batches_yielded = 0
-    for row in iter_rows(path):
-        cont_row = [float(row[c]) for c in cont_cols]
-        disc_row = [vocab[c].get(row[c], vocab[c]["<UNK>"]) for c in disc_cols]
-        seq_cont.append(cont_row)
-        seq_disc.append(disc_row)
-        if len(seq_cont) == seq_len:
-            flush_seq()
-            if len(batch_cont) == batch_size:
-                x_cont = torch.tensor(batch_cont, dtype=torch.float32)
-                x_disc = torch.tensor(batch_disc, dtype=torch.long)
-                x_cont = normalize_cont(x_cont, cont_cols, mean, std)
-                yield x_cont, x_disc
-                batch_cont = []
-                batch_disc = []
-                batches_yielded += 1
-                if max_batches is not None and batches_yielded >= max_batches:
-                    return
+    paths = [path] if isinstance(path, str) else list(path)
+    for file_id, p in enumerate(paths):
+        for row in iter_rows(p):
+            cont_row = [float(row[c]) for c in cont_cols]
+            disc_row = [vocab[c].get(row[c], vocab[c]["<UNK>"]) for c in disc_cols]
+            seq_cont.append(cont_row)
+            seq_disc.append(disc_row)
+            if len(seq_cont) == seq_len:
+                flush_seq()
+                if return_file_id:
+                    batch_file.append(file_id)
+                if len(batch_cont) == batch_size:
+                    x_cont = torch.tensor(batch_cont, dtype=torch.float32)
+                    x_disc = torch.tensor(batch_disc, dtype=torch.long)
+                    x_cont = normalize_cont(x_cont, cont_cols, mean, std)
+                    if return_file_id:
+                        x_file = torch.tensor(batch_file, dtype=torch.long)
+                        yield x_cont, x_disc, x_file
+                    else:
+                        yield x_cont, x_disc
+                    batch_cont = []
+                    batch_disc = []
+                    batch_file = []
+                    batches_yielded += 1
+                    if max_batches is not None and batches_yielded >= max_batches:
+                        return
+        # drop partial sequence at file boundary
+        seq_cont = []
+        seq_disc = []
 
     # Drop last partial batch for simplicity
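With return_file_id=True, windowed_batches now yields a third tensor mapping each sequence in the batch back to the index of its source file, and it resets the partial sequence at every file boundary so no window spans two files. A sketch of a call, assuming the parameters elided from the hunks (between vocab and batch_size) are mean and std for normalize_cont, with hypothetical file and column names:

    files = ["jan.csv.gz", "feb.csv.gz"]                 # hypothetical input files
    cont_cols, disc_cols = ["price", "qty"], ["symbol"]  # hypothetical columns
    mean, std = compute_cont_stats(files, cont_cols)
    vocab = build_vocab(files, disc_cols)
    for x_cont, x_disc, x_file in windowed_batches(
        files, cont_cols, disc_cols, vocab,
        mean, std,  # assumed: the parameters elided between vocab and batch_size
        batch_size=32, seq_len=64, return_file_id=True,
    ):
        # x_cont: float32 (32, 64, 2); x_disc: long (32, 64, 1); x_file: long (32,)
        print(x_file.bincount(minlength=len(files)))  # sequences per source file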