diff --git a/example/README.md b/example/README.md index 32ffe81..b88330f 100644 --- a/example/README.md +++ b/example/README.md @@ -67,6 +67,7 @@ python example/run_pipeline.py --device auto - Optional conditioning by file id (`train*.csv.gz`) is enabled by default for multi-file training. - Continuous head can be bounded with `tanh` via `use_tanh_eps` in config. - Export now clamps continuous features to training min/max and preserves integer/decimal precision. +- Continuous features may be log1p-transformed automatically for heavy-tailed columns (see cont_stats.json). - `` tokens are replaced by the most frequent token for each discrete column at export. - The script only samples the first 5000 rows to stay fast. - `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it. diff --git a/example/config.json b/example/config.json index d5feaf6..58baa92 100644 --- a/example/config.json +++ b/example/config.json @@ -6,11 +6,11 @@ "vocab_path": "./results/disc_vocab.json", "out_dir": "./results", "device": "auto", - "timesteps": 400, + "timesteps": 600, "batch_size": 128, "seq_len": 128, - "epochs": 8, - "max_batches": 3000, + "epochs": 10, + "max_batches": 4000, "lambda": 0.5, "lr": 0.0005, "seed": 1337, @@ -23,8 +23,17 @@ "use_condition": true, "condition_type": "file_id", "cond_dim": 32, - "use_tanh_eps": true, + "use_tanh_eps": false, "eps_scale": 1.0, + "model_time_dim": 128, + "model_hidden_dim": 512, + "model_num_layers": 2, + "model_dropout": 0.1, + "model_ff_mult": 2, + "model_pos_dim": 64, + "model_use_pos_embed": true, + "disc_mask_scale": 0.9, + "shuffle_buffer": 256, "sample_batch_size": 8, "sample_seq_len": 128 } diff --git a/example/data_utils.py b/example/data_utils.py index a38b221..aa36195 100755 --- a/example/data_utils.py +++ b/example/data_utils.py @@ -4,6 +4,8 @@ import csv import gzip import json +import math +import random from typing import Dict, Iterable, List, Optional, Tuple, Union @@ -23,62 +25,173 @@ def iter_rows(path_or_paths: Union[str, List[str]]) -> Iterable[Dict[str, str]]: yield row -def compute_cont_stats( +def _stream_basic_stats( path: Union[str, List[str]], cont_cols: List[str], max_rows: Optional[int] = None, -) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, bool], Dict[str, int]]: - """Streaming mean/std (Welford) + min/max + int/precision metadata.""" - count = 0 +): + """Streaming stats with mean/M2/M3 + min/max + int/precision metadata.""" + count = {c: 0 for c in cont_cols} mean = {c: 0.0 for c in cont_cols} m2 = {c: 0.0 for c in cont_cols} + m3 = {c: 0.0 for c in cont_cols} vmin = {c: float("inf") for c in cont_cols} vmax = {c: float("-inf") for c in cont_cols} int_like = {c: True for c in cont_cols} max_decimals = {c: 0 for c in cont_cols} + all_pos = {c: True for c in cont_cols} for i, row in enumerate(iter_rows(path)): - count += 1 for c in cont_cols: raw = row[c] if raw is None or raw == "": continue x = float(raw) - delta = x - mean[c] - mean[c] += delta / count - delta2 = x - mean[c] - m2[c] += delta * delta2 + if x <= 0: + all_pos[c] = False if x < vmin[c]: vmin[c] = x if x > vmax[c]: vmax[c] = x if int_like[c] and abs(x - round(x)) > 1e-9: int_like[c] = False - # track decimal places from raw string if possible - if "e" in raw or "E" in raw: - # scientific notation, skip precision inference - continue - if "." in raw: + if "e" not in raw and "E" not in raw and "." in raw: dec = raw.split(".", 1)[1].rstrip("0") if len(dec) > max_decimals[c]: max_decimals[c] = len(dec) + + n = count[c] + 1 + delta = x - mean[c] + delta_n = delta / n + term1 = delta * delta_n * (n - 1) + m3[c] += term1 * delta_n * (n - 2) - 3 * delta_n * m2[c] + m2[c] += term1 + mean[c] += delta_n + count[c] = n + if max_rows is not None and i + 1 >= max_rows: break + # finalize std/skew std = {} + skew = {} for c in cont_cols: - if count > 1: - var = m2[c] / (count - 1) + n = count[c] + if n > 1: + var = m2[c] / (n - 1) else: var = 0.0 std[c] = var ** 0.5 if var > 0 else 1.0 - # replace infs if column had no valid values + if n > 2 and m2[c] > 0: + skew[c] = (math.sqrt(n) * (m3[c] / n)) / (m2[c] ** 1.5) + else: + skew[c] = 0.0 + for c in cont_cols: if vmin[c] == float("inf"): vmin[c] = 0.0 if vmax[c] == float("-inf"): vmax[c] = 0.0 - return mean, std, vmin, vmax, int_like, max_decimals + + return { + "count": count, + "mean": mean, + "std": std, + "m2": m2, + "m3": m3, + "min": vmin, + "max": vmax, + "int_like": int_like, + "max_decimals": max_decimals, + "skew": skew, + "all_pos": all_pos, + } + + +def choose_cont_transforms( + path: Union[str, List[str]], + cont_cols: List[str], + max_rows: Optional[int] = None, + skew_threshold: float = 1.5, + range_ratio_threshold: float = 1e3, +): + """Pick per-column transform (currently log1p or none) based on skew/range.""" + stats = _stream_basic_stats(path, cont_cols, max_rows=max_rows) + transforms = {} + for c in cont_cols: + if not stats["all_pos"][c]: + transforms[c] = "none" + continue + skew = abs(stats["skew"][c]) + vmin = stats["min"][c] + vmax = stats["max"][c] + ratio = (vmax / vmin) if vmin > 0 else 0.0 + if skew >= skew_threshold or ratio >= range_ratio_threshold: + transforms[c] = "log1p" + else: + transforms[c] = "none" + return transforms, stats + + +def compute_cont_stats( + path: Union[str, List[str]], + cont_cols: List[str], + max_rows: Optional[int] = None, + transforms: Optional[Dict[str, str]] = None, +): + """Compute stats on (optionally transformed) values. Returns raw + transformed stats.""" + # First pass (raw) for metadata and raw mean/std + raw = _stream_basic_stats(path, cont_cols, max_rows=max_rows) + + # Optional transform selection + if transforms is None: + transforms = {c: "none" for c in cont_cols} + + # Second pass for transformed mean/std + count = {c: 0 for c in cont_cols} + mean = {c: 0.0 for c in cont_cols} + m2 = {c: 0.0 for c in cont_cols} + for i, row in enumerate(iter_rows(path)): + for c in cont_cols: + raw_val = row[c] + if raw_val is None or raw_val == "": + continue + x = float(raw_val) + if transforms.get(c) == "log1p": + if x < 0: + x = 0.0 + x = math.log1p(x) + n = count[c] + 1 + delta = x - mean[c] + mean[c] += delta / n + delta2 = x - mean[c] + m2[c] += delta * delta2 + count[c] = n + if max_rows is not None and i + 1 >= max_rows: + break + + std = {} + for c in cont_cols: + if count[c] > 1: + var = m2[c] / (count[c] - 1) + else: + var = 0.0 + std[c] = var ** 0.5 if var > 0 else 1.0 + + return { + "mean": mean, + "std": std, + "raw_mean": raw["mean"], + "raw_std": raw["std"], + "min": raw["min"], + "max": raw["max"], + "int_like": raw["int_like"], + "max_decimals": raw["max_decimals"], + "transform": transforms, + "skew": raw["skew"], + "all_pos": raw["all_pos"], + "max_rows": max_rows, + } def build_vocab( @@ -130,8 +243,19 @@ def build_disc_stats( return vocab, top_token -def normalize_cont(x, cont_cols: List[str], mean: Dict[str, float], std: Dict[str, float]): +def normalize_cont( + x, + cont_cols: List[str], + mean: Dict[str, float], + std: Dict[str, float], + transforms: Optional[Dict[str, str]] = None, +): import torch + + if transforms: + for i, c in enumerate(cont_cols): + if transforms.get(c) == "log1p": + x[:, :, i] = torch.log1p(torch.clamp(x[:, :, i], min=0)) mean_t = torch.tensor([mean[c] for c in cont_cols], dtype=x.dtype, device=x.device) std_t = torch.tensor([std[c] for c in cont_cols], dtype=x.dtype, device=x.device) return (x - mean_t) / std_t @@ -148,19 +272,34 @@ def windowed_batches( seq_len: int, max_batches: Optional[int] = None, return_file_id: bool = False, + transforms: Optional[Dict[str, str]] = None, + shuffle_buffer: int = 0, ): import torch batch_cont = [] batch_disc = [] batch_file = [] + buffer = [] seq_cont = [] seq_disc = [] - def flush_seq(): - nonlocal seq_cont, seq_disc, batch_cont, batch_disc + def flush_seq(file_id: int): + nonlocal seq_cont, seq_disc, batch_cont, batch_disc, batch_file if len(seq_cont) == seq_len: - batch_cont.append(seq_cont) - batch_disc.append(seq_disc) + if shuffle_buffer and shuffle_buffer > 0: + buffer.append((list(seq_cont), list(seq_disc), file_id)) + if len(buffer) >= shuffle_buffer: + idx = random.randrange(len(buffer)) + seq_c, seq_d, seq_f = buffer.pop(idx) + batch_cont.append(seq_c) + batch_disc.append(seq_d) + if return_file_id: + batch_file.append(seq_f) + else: + batch_cont.append(seq_cont) + batch_disc.append(seq_disc) + if return_file_id: + batch_file.append(file_id) seq_cont = [] seq_disc = [] @@ -173,13 +312,11 @@ def windowed_batches( seq_cont.append(cont_row) seq_disc.append(disc_row) if len(seq_cont) == seq_len: - flush_seq() - if return_file_id: - batch_file.append(file_id) + flush_seq(file_id) if len(batch_cont) == batch_size: x_cont = torch.tensor(batch_cont, dtype=torch.float32) x_disc = torch.tensor(batch_disc, dtype=torch.long) - x_cont = normalize_cont(x_cont, cont_cols, mean, std) + x_cont = normalize_cont(x_cont, cont_cols, mean, std, transforms=transforms) if return_file_id: x_file = torch.tensor(batch_file, dtype=torch.long) yield x_cont, x_disc, x_file @@ -195,4 +332,29 @@ def windowed_batches( seq_cont = [] seq_disc = [] + # Flush any remaining buffered sequences + if shuffle_buffer and buffer: + random.shuffle(buffer) + for seq_c, seq_d, seq_f in buffer: + batch_cont.append(seq_c) + batch_disc.append(seq_d) + if return_file_id: + batch_file.append(seq_f) + if len(batch_cont) == batch_size: + import torch + x_cont = torch.tensor(batch_cont, dtype=torch.float32) + x_disc = torch.tensor(batch_disc, dtype=torch.long) + x_cont = normalize_cont(x_cont, cont_cols, mean, std, transforms=transforms) + if return_file_id: + x_file = torch.tensor(batch_file, dtype=torch.long) + yield x_cont, x_disc, x_file + else: + yield x_cont, x_disc + batch_cont = [] + batch_disc = [] + batch_file = [] + batches_yielded += 1 + if max_batches is not None and batches_yielded >= max_batches: + return + # Drop last partial batch for simplicity diff --git a/example/evaluate_generated.py b/example/evaluate_generated.py index b5431e8..7ed87e6 100644 --- a/example/evaluate_generated.py +++ b/example/evaluate_generated.py @@ -5,8 +5,9 @@ import argparse import csv import gzip import json +import math from pathlib import Path -from typing import Dict, Tuple +from typing import Dict, Tuple, List, Optional def load_json(path: str) -> Dict: @@ -28,6 +29,8 @@ def parse_args(): parser.add_argument("--stats", default=str(base_dir / "results" / "cont_stats.json")) parser.add_argument("--vocab", default=str(base_dir / "results" / "disc_vocab.json")) parser.add_argument("--out", default=str(base_dir / "results" / "eval.json")) + parser.add_argument("--reference", default="", help="Optional reference CSV (train) for richer metrics") + parser.add_argument("--max-rows", type=int, default=20000, help="Max rows to load for reference metrics") return parser.parse_args() @@ -55,6 +58,62 @@ def finalize_stats(stats): return out +def js_divergence(p, q, eps: float = 1e-12) -> float: + p = [max(x, eps) for x in p] + q = [max(x, eps) for x in q] + m = [(pi + qi) / 2.0 for pi, qi in zip(p, q)] + def kl(a, b): + return sum(ai * math.log(ai / bi, 2) for ai, bi in zip(a, b)) + return 0.5 * kl(p, m) + 0.5 * kl(q, m) + + +def ks_statistic(x: List[float], y: List[float]) -> float: + if not x or not y: + return 0.0 + x_sorted = sorted(x) + y_sorted = sorted(y) + n = len(x_sorted) + m = len(y_sorted) + i = j = 0 + cdf_x = cdf_y = 0.0 + d = 0.0 + while i < n and j < m: + if x_sorted[i] <= y_sorted[j]: + i += 1 + cdf_x = i / n + else: + j += 1 + cdf_y = j / m + d = max(d, abs(cdf_x - cdf_y)) + return d + + +def lag1_corr(values: List[float]) -> float: + if len(values) < 3: + return 0.0 + x = values[:-1] + y = values[1:] + mean_x = sum(x) / len(x) + mean_y = sum(y) / len(y) + num = sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y)) + den_x = sum((xi - mean_x) ** 2 for xi in x) + den_y = sum((yi - mean_y) ** 2 for yi in y) + if den_x <= 0 or den_y <= 0: + return 0.0 + return num / math.sqrt(den_x * den_y) + + +def resolve_reference_path(path: str) -> Optional[str]: + if not path: + return None + if any(ch in path for ch in ["*", "?", "["]): + base = Path(path).parent + pat = Path(path).name + matches = sorted(base.glob(pat)) + return str(matches[0]) if matches else None + return str(path) + + def main(): args = parse_args() base_dir = Path(__file__).resolve().parent @@ -63,13 +122,18 @@ def main(): args.stats = str((base_dir / args.stats).resolve()) if not Path(args.stats).is_absolute() else args.stats args.vocab = str((base_dir / args.vocab).resolve()) if not Path(args.vocab).is_absolute() else args.vocab args.out = str((base_dir / args.out).resolve()) if not Path(args.out).is_absolute() else args.out + if args.reference and not Path(args.reference).is_absolute(): + args.reference = str((base_dir / args.reference).resolve()) + ref_path = resolve_reference_path(args.reference) split = load_json(args.split) time_col = split.get("time_column", "time") cont_cols = [c for c in split["continuous"] if c != time_col] disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col] - stats_ref = load_json(args.stats)["mean"] - std_ref = load_json(args.stats)["std"] + stats_json = load_json(args.stats) + stats_ref = stats_json.get("raw_mean", stats_json.get("mean")) + std_ref = stats_json.get("raw_std", stats_json.get("std")) + transforms = stats_json.get("transform", {}) vocab = load_json(args.vocab)["vocab"] vocab_sets = {c: set(vocab[c].keys()) for c in disc_cols} @@ -89,6 +153,8 @@ def main(): except Exception: v = 0.0 update_stats(cont_stats, c, v) + if ref_path: + pass for c in disc_cols: if row[c] not in vocab_sets[c]: disc_invalid[c] += 1 @@ -112,6 +178,81 @@ def main(): "discrete_invalid_counts": disc_invalid, } + # Optional richer metrics using reference data + if ref_path: + ref_cont = {c: [] for c in cont_cols} + ref_disc = {c: {} for c in disc_cols} + gen_cont = {c: [] for c in cont_cols} + gen_disc = {c: {} for c in disc_cols} + + with open_csv(args.generated) as f: + reader = csv.DictReader(f) + for row in reader: + if time_col in row: + row.pop(time_col, None) + for c in cont_cols: + try: + gen_cont[c].append(float(row[c])) + except Exception: + gen_cont[c].append(0.0) + for c in disc_cols: + tok = row[c] + gen_disc[c][tok] = gen_disc[c].get(tok, 0) + 1 + + with open_csv(ref_path) as f: + reader = csv.DictReader(f) + for i, row in enumerate(reader): + if time_col in row: + row.pop(time_col, None) + for c in cont_cols: + try: + ref_cont[c].append(float(row[c])) + except Exception: + ref_cont[c].append(0.0) + for c in disc_cols: + tok = row[c] + ref_disc[c][tok] = ref_disc[c].get(tok, 0) + 1 + if args.max_rows and i + 1 >= args.max_rows: + break + + # Continuous metrics: KS + quantiles + lag1 correlation + cont_ks = {} + cont_quant = {} + cont_lag1 = {} + for c in cont_cols: + cont_ks[c] = ks_statistic(gen_cont[c], ref_cont[c]) + ref_sorted = sorted(ref_cont[c]) + gen_sorted = sorted(gen_cont[c]) + qs = [0.05, 0.25, 0.5, 0.75, 0.95] + def qval(arr, q): + if not arr: + return 0.0 + idx = int(q * (len(arr) - 1)) + return arr[idx] + cont_quant[c] = { + "q05_diff": abs(qval(gen_sorted, 0.05) - qval(ref_sorted, 0.05)), + "q25_diff": abs(qval(gen_sorted, 0.25) - qval(ref_sorted, 0.25)), + "q50_diff": abs(qval(gen_sorted, 0.5) - qval(ref_sorted, 0.5)), + "q75_diff": abs(qval(gen_sorted, 0.75) - qval(ref_sorted, 0.75)), + "q95_diff": abs(qval(gen_sorted, 0.95) - qval(ref_sorted, 0.95)), + } + cont_lag1[c] = abs(lag1_corr(gen_cont[c]) - lag1_corr(ref_cont[c])) + + # Discrete metrics: JSD over vocab + disc_jsd = {} + for c in disc_cols: + vocab_vals = list(vocab_sets[c]) + gen_total = sum(gen_disc[c].values()) or 1 + ref_total = sum(ref_disc[c].values()) or 1 + p = [gen_disc[c].get(v, 0) / gen_total for v in vocab_vals] + q = [ref_disc[c].get(v, 0) / ref_total for v in vocab_vals] + disc_jsd[c] = js_divergence(p, q) + + report["continuous_ks"] = cont_ks + report["continuous_quantile_diff"] = cont_quant + report["continuous_lag1_diff"] = cont_lag1 + report["discrete_jsd"] = disc_jsd + with open(args.out, "w", encoding="utf-8") as f: json.dump(report, f, indent=2) diff --git a/example/export_samples.py b/example/export_samples.py index 3bfed15..9c46660 100644 --- a/example/export_samples.py +++ b/example/export_samples.py @@ -111,6 +111,7 @@ def main(): vmax = stats.get("max", {}) int_like = stats.get("int_like", {}) max_decimals = stats.get("max_decimals", {}) + transforms = stats.get("transform", {}) vocab_json = json.load(open(args.vocab_path, "r", encoding="utf-8")) vocab = vocab_json["vocab"] @@ -141,6 +142,13 @@ def main(): model = HybridDiffusionModel( cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes, + time_dim=int(cfg.get("model_time_dim", 64)), + hidden_dim=int(cfg.get("model_hidden_dim", 256)), + num_layers=int(cfg.get("model_num_layers", 1)), + dropout=float(cfg.get("model_dropout", 0.0)), + ff_mult=int(cfg.get("model_ff_mult", 2)), + pos_dim=int(cfg.get("model_pos_dim", 64)), + use_pos_embed=bool(cfg.get("model_use_pos_embed", True)), cond_vocab_size=cond_vocab_size if use_condition else 0, cond_dim=int(cfg.get("cond_dim", 32)), use_tanh_eps=bool(cfg.get("use_tanh_eps", False)), @@ -220,6 +228,9 @@ def main(): mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype) std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype) x_cont = x_cont * std_vec + mean_vec + for i, c in enumerate(cont_cols): + if transforms.get(c) == "log1p": + x_cont[:, :, i] = torch.expm1(x_cont[:, :, i]) # clamp to observed min/max per feature if vmin and vmax: for i, c in enumerate(cont_cols): @@ -246,8 +257,8 @@ def main(): row["__cond_file_id"] = str(int(cond[b].item())) if cond is not None else "-1" if args.include_time and time_col in header: row[time_col] = str(row_index) - for i, c in enumerate(cont_cols): - val = float(x_cont[b, t, i]) + for i, c in enumerate(cont_cols): + val = float(x_cont[b, t, i]) if int_like.get(c, False): row[c] = str(int(round(val))) else: diff --git a/example/hybrid_diffusion.py b/example/hybrid_diffusion.py index 506a7ad..4f37356 100755 --- a/example/hybrid_diffusion.py +++ b/example/hybrid_diffusion.py @@ -35,11 +35,14 @@ def q_sample_discrete( t: torch.Tensor, mask_tokens: torch.Tensor, max_t: int, + mask_scale: float = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor]: """Randomly mask discrete tokens with a cosine schedule over t.""" bsz = x0.size(0) # cosine schedule: p(0)=0, p(max_t)=1 p = 0.5 * (1.0 - torch.cos(math.pi * t.float() / float(max_t))) + if mask_scale != 1.0: + p = torch.clamp(p * mask_scale, 0.0, 1.0) p = p.view(bsz, 1, 1) mask = torch.rand_like(x0.float()) < p x_masked = x0.clone() @@ -70,6 +73,11 @@ class HybridDiffusionModel(nn.Module): disc_vocab_sizes: List[int], time_dim: int = 64, hidden_dim: int = 256, + num_layers: int = 1, + dropout: float = 0.0, + ff_mult: int = 2, + pos_dim: int = 64, + use_pos_embed: bool = True, cond_vocab_size: int = 0, cond_dim: int = 32, use_tanh_eps: bool = False, @@ -82,6 +90,8 @@ class HybridDiffusionModel(nn.Module): self.time_embed = SinusoidalTimeEmbedding(time_dim) self.use_tanh_eps = use_tanh_eps self.eps_scale = eps_scale + self.pos_dim = pos_dim + self.use_pos_embed = use_pos_embed self.cond_vocab_size = cond_vocab_size self.cond_dim = cond_dim @@ -96,9 +106,22 @@ class HybridDiffusionModel(nn.Module): disc_embed_dim = sum(e.embedding_dim for e in self.disc_embeds) self.cont_proj = nn.Linear(cont_dim, cont_dim) - in_dim = cont_dim + disc_embed_dim + time_dim + (cond_dim if self.cond_embed is not None else 0) + pos_dim = pos_dim if use_pos_embed else 0 + in_dim = cont_dim + disc_embed_dim + time_dim + pos_dim + (cond_dim if self.cond_embed is not None else 0) self.in_proj = nn.Linear(in_dim, hidden_dim) - self.backbone = nn.GRU(hidden_dim, hidden_dim, batch_first=True) + self.backbone = nn.GRU( + hidden_dim, + hidden_dim, + num_layers=num_layers, + dropout=dropout if num_layers > 1 else 0.0, + batch_first=True, + ) + self.post_norm = nn.LayerNorm(hidden_dim) + self.post_ff = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim * ff_mult), + nn.GELU(), + nn.Linear(hidden_dim * ff_mult, hidden_dim), + ) self.cont_head = nn.Linear(hidden_dim, cont_dim) self.disc_heads = nn.ModuleList([ @@ -110,6 +133,9 @@ class HybridDiffusionModel(nn.Module): """x_cont: (B,T,Cc), x_disc: (B,T,Cd) with integer tokens.""" time_emb = self.time_embed(t) time_emb = time_emb.unsqueeze(1).expand(-1, x_cont.size(1), -1) + pos_emb = None + if self.use_pos_embed and self.pos_dim > 0: + pos_emb = self._positional_encoding(x_cont.size(1), self.pos_dim, x_cont.device) disc_embs = [] for i, emb in enumerate(self.disc_embeds): @@ -124,15 +150,28 @@ class HybridDiffusionModel(nn.Module): cont_feat = self.cont_proj(x_cont) parts = [cont_feat, disc_feat, time_emb] + if pos_emb is not None: + parts.append(pos_emb.unsqueeze(0).expand(x_cont.size(0), -1, -1)) if cond_feat is not None: parts.append(cond_feat) feat = torch.cat(parts, dim=-1) feat = self.in_proj(feat) out, _ = self.backbone(feat) + out = self.post_norm(out) + out = out + self.post_ff(out) eps_pred = self.cont_head(out) if self.use_tanh_eps: eps_pred = torch.tanh(eps_pred) * self.eps_scale logits = [head(out) for head in self.disc_heads] return eps_pred, logits + + @staticmethod + def _positional_encoding(seq_len: int, dim: int, device: torch.device) -> torch.Tensor: + pos = torch.arange(seq_len, device=device).float().unsqueeze(1) + div = torch.exp(torch.arange(0, dim, 2, device=device).float() * (-math.log(10000.0) / dim)) + pe = torch.zeros(seq_len, dim, device=device) + pe[:, 0::2] = torch.sin(pos * div) + pe[:, 1::2] = torch.cos(pos * div) + return pe diff --git a/example/prepare_data.py b/example/prepare_data.py index 26eac22..e0427ac 100755 --- a/example/prepare_data.py +++ b/example/prepare_data.py @@ -5,7 +5,7 @@ import json from pathlib import Path from typing import Optional -from data_utils import compute_cont_stats, build_disc_stats, load_split +from data_utils import compute_cont_stats, build_disc_stats, load_split, choose_cont_transforms from platform_utils import safe_path, ensure_dir BASE_DIR = Path(__file__).resolve().parent @@ -27,20 +27,25 @@ def main(max_rows: Optional[int] = None): raise SystemExit("no train files found under %s" % str(DATA_GLOB)) data_paths = [safe_path(p) for p in data_paths] - mean, std, vmin, vmax, int_like, max_decimals = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows) + transforms, _ = choose_cont_transforms(data_paths, cont_cols, max_rows=max_rows) + cont_stats = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows, transforms=transforms) vocab, top_token = build_disc_stats(data_paths, disc_cols, max_rows=max_rows) ensure_dir(OUT_STATS.parent) with open(safe_path(OUT_STATS), "w", encoding="utf-8") as f: json.dump( { - "mean": mean, - "std": std, - "min": vmin, - "max": vmax, - "int_like": int_like, - "max_decimals": max_decimals, - "max_rows": max_rows, + "mean": cont_stats["mean"], + "std": cont_stats["std"], + "raw_mean": cont_stats["raw_mean"], + "raw_std": cont_stats["raw_std"], + "min": cont_stats["min"], + "max": cont_stats["max"], + "int_like": cont_stats["int_like"], + "max_decimals": cont_stats["max_decimals"], + "transform": cont_stats["transform"], + "skew": cont_stats["skew"], + "max_rows": cont_stats["max_rows"], }, f, indent=2, diff --git a/example/results/cont_stats.json b/example/results/cont_stats.json index 2faae18..c12ab06 100644 --- a/example/results/cont_stats.json +++ b/example/results/cont_stats.json @@ -45,7 +45,7 @@ "P3_PIT01": 668.9722350000003, "P4_HT_FD": -0.00010012580000000082, "P4_HT_LD": 35.41945000099953, - "P4_HT_PO": 35.4085699912002, + "P4_HT_PO": 2.6391372939040414, "P4_LD": 365.3833745803986, "P4_ST_FD": -6.5205999999999635e-06, "P4_ST_GOV": 17801.81294499996, @@ -100,7 +100,7 @@ "P3_PIT01": 1168.1071264424027, "P4_HT_FD": 0.002032582380617592, "P4_HT_LD": 33.212361169253235, - "P4_HT_PO": 31.187825914515162, + "P4_HT_PO": 1.7636196192459512, "P4_LD": 59.736616589045646, "P4_ST_FD": 0.0016428787127432496, "P4_ST_GOV": 1740.5997458128215, @@ -109,6 +109,116 @@ "P4_ST_PT01": 22.459962818146252, "P4_ST_TT01": 24.745939350221477 }, + "raw_mean": { + "P1_B2004": 0.08649086820000026, + "P1_B2016": 1.376161456000001, + "P1_B3004": 396.1861596906018, + "P1_B3005": 1037.372384413793, + "P1_B4002": 32.564872940799994, + "P1_B4005": 65.98190757240047, + "P1_B400B": 1925.0391570245934, + "P1_B4022": 36.28908066800001, + "P1_FCV02Z": 21.744261118400036, + "P1_FCV03D": 57.36123274140044, + "P1_FCV03Z": 58.05084519640002, + "P1_FT01": 184.18615112319728, + "P1_FT01Z": 851.8781750705965, + "P1_FT02": 1255.8572173544069, + "P1_FT02Z": 1925.0210755194114, + "P1_FT03": 269.37285885780574, + "P1_FT03Z": 1037.366172230601, + "P1_LCV01D": 11.228849048599963, + "P1_LCV01Z": 10.991610181600016, + "P1_LIT01": 396.8845311109994, + "P1_PCV01D": 53.80101618419986, + "P1_PCV01Z": 54.646640287199595, + "P1_PCV02Z": 12.017773542800072, + "P1_PIT01": 1.3692859488000075, + "P1_PIT02": 0.44459071260000227, + "P1_TIT01": 35.64255813999988, + "P1_TIT02": 36.44807823060023, + "P2_24Vdc": 28.0280019013999, + "P2_CO_rpm": 54105.64434999997, + "P2_HILout": 712.0588667425922, + "P2_MSD": 763.19324, + "P2_SIT01": 778.7769850000013, + "P2_SIT02": 778.7778935471981, + "P2_VT01": 11.914949448200044, + "P2_VXT02": -3.5267871940000175, + "P2_VXT03": -1.5520904921999914, + "P2_VYT02": 3.796112737600002, + "P2_VYT03": 6.121691697000018, + "P3_FIT01": 1168.2528800000014, + "P3_LCP01D": 4675.465239999989, + "P3_LCV01D": 7445.208720000017, + "P3_LIT01": 13728.982314999852, + "P3_PIT01": 668.9722350000003, + "P4_HT_FD": -0.00010012580000000082, + "P4_HT_LD": 35.41945000099953, + "P4_HT_PO": 35.4085699912002, + "P4_LD": 365.3833745803986, + "P4_ST_FD": -6.5205999999999635e-06, + "P4_ST_GOV": 17801.81294499996, + "P4_ST_LD": 329.83259218199964, + "P4_ST_PO": 330.1079461497967, + "P4_ST_PT01": 10047.679605000127, + "P4_ST_TT01": 27606.860070000155 + }, + "raw_std": { + "P1_B2004": 0.024492489898690458, + "P1_B2016": 0.12949272564759745, + "P1_B3004": 10.16264800653289, + "P1_B3005": 70.85697659109, + "P1_B4002": 0.7578213113008355, + "P1_B4005": 41.80065314991797, + "P1_B400B": 1176.6445547448632, + "P1_B4022": 0.8221115066487089, + "P1_FCV02Z": 39.11843197764177, + "P1_FCV03D": 7.889507447726625, + "P1_FCV03Z": 8.046068905945717, + "P1_FT01": 30.80117031882856, + "P1_FT01Z": 91.2786865433318, + "P1_FT02": 879.7163277334494, + "P1_FT02Z": 1176.6699531305117, + "P1_FT03": 38.18015841964941, + "P1_FT03Z": 70.73100774436428, + "P1_LCV01D": 3.3355655415557597, + "P1_LCV01Z": 3.386332233773545, + "P1_LIT01": 10.578714760104122, + "P1_PCV01D": 19.61567943613885, + "P1_PCV01Z": 19.778754467302086, + "P1_PCV02Z": 0.0048047978931599995, + "P1_PIT01": 0.0776614954053113, + "P1_PIT02": 0.44823231815652304, + "P1_TIT01": 0.5986678527528814, + "P1_TIT02": 1.1892341204521049, + "P2_24Vdc": 0.003208842504097816, + "P2_CO_rpm": 20.575477821507334, + "P2_HILout": 8.178853379908608, + "P2_MSD": 1.0, + "P2_SIT01": 3.894535775667256, + "P2_SIT02": 3.8824770788579395, + "P2_VT01": 0.06812990916670247, + "P2_VXT02": 0.43104157117568803, + "P2_VXT03": 0.26894251958139775, + "P2_VYT02": 0.4610907883207586, + "P2_VYT03": 0.30596429385075474, + "P3_FIT01": 1787.2987693141868, + "P3_LCP01D": 5145.4094261812725, + "P3_LCV01D": 6785.602781765096, + "P3_LIT01": 4060.915441872745, + "P3_PIT01": 1168.1071264424027, + "P4_HT_FD": 0.002032582380617592, + "P4_HT_LD": 33.21236116925323, + "P4_HT_PO": 31.18782591451516, + "P4_LD": 59.736616589045646, + "P4_ST_FD": 0.0016428787127432496, + "P4_ST_GOV": 1740.5997458128213, + "P4_ST_LD": 35.86633288900077, + "P4_ST_PO": 32.375012735256696, + "P4_ST_PT01": 22.45996281814625, + "P4_ST_TT01": 24.745939350221487 + }, "min": { "P1_B2004": 0.03051, "P1_B2016": 0.94729, @@ -329,5 +439,115 @@ "P4_ST_PT01": 2, "P4_ST_TT01": 2 }, + "transform": { + "P1_B2004": "none", + "P1_B2016": "none", + "P1_B3004": "none", + "P1_B3005": "none", + "P1_B4002": "none", + "P1_B4005": "none", + "P1_B400B": "none", + "P1_B4022": "none", + "P1_FCV02Z": "none", + "P1_FCV03D": "none", + "P1_FCV03Z": "none", + "P1_FT01": "none", + "P1_FT01Z": "none", + "P1_FT02": "none", + "P1_FT02Z": "none", + "P1_FT03": "none", + "P1_FT03Z": "none", + "P1_LCV01D": "none", + "P1_LCV01Z": "none", + "P1_LIT01": "none", + "P1_PCV01D": "none", + "P1_PCV01Z": "none", + "P1_PCV02Z": "none", + "P1_PIT01": "none", + "P1_PIT02": "none", + "P1_TIT01": "none", + "P1_TIT02": "none", + "P2_24Vdc": "none", + "P2_CO_rpm": "none", + "P2_HILout": "none", + "P2_MSD": "none", + "P2_SIT01": "none", + "P2_SIT02": "none", + "P2_VT01": "none", + "P2_VXT02": "none", + "P2_VXT03": "none", + "P2_VYT02": "none", + "P2_VYT03": "none", + "P3_FIT01": "none", + "P3_LCP01D": "none", + "P3_LCV01D": "none", + "P3_LIT01": "none", + "P3_PIT01": "none", + "P4_HT_FD": "none", + "P4_HT_LD": "none", + "P4_HT_PO": "log1p", + "P4_LD": "none", + "P4_ST_FD": "none", + "P4_ST_GOV": "none", + "P4_ST_LD": "none", + "P4_ST_PO": "none", + "P4_ST_PT01": "none", + "P4_ST_TT01": "none" + }, + "skew": { + "P1_B2004": -2.876938578031295e-05, + "P1_B2016": 2.014565216651284e-06, + "P1_B3004": 6.625985939357487e-06, + "P1_B3005": -9.917489652810193e-06, + "P1_B4002": 1.4641465884161855e-05, + "P1_B4005": -1.2370279269006856e-05, + "P1_B400B": -1.4116897198097317e-05, + "P1_B4022": 1.1162291352215598e-05, + "P1_FCV02Z": 2.532521501167817e-05, + "P1_FCV03D": 4.2517931711793e-06, + "P1_FCV03Z": 4.301856332440012e-06, + "P1_FT01": -1.3345735264961829e-05, + "P1_FT01Z": -4.2554413198354234e-05, + "P1_FT02": -1.0289230789249066e-05, + "P1_FT02Z": -1.4116856909216661e-05, + "P1_FT03": -4.341090521713463e-06, + "P1_FT03Z": -9.964308983887345e-06, + "P1_LCV01D": 2.541312481372867e-06, + "P1_LCV01Z": 2.5806433622267527e-06, + "P1_LIT01": 7.716120912717401e-06, + "P1_PCV01D": 2.113459306618771e-05, + "P1_PCV01Z": 2.0632832525407433e-05, + "P1_PCV02Z": 4.2639616636720384e-08, + "P1_PIT01": 2.079887220863843e-05, + "P1_PIT02": 5.003507344873546e-05, + "P1_TIT01": 9.553657000925262e-06, + "P1_TIT02": 2.1170357380515215e-05, + "P2_24Vdc": 2.735770838906968e-07, + "P2_CO_rpm": -8.124011608472296e-06, + "P2_HILout": -4.086282393330704e-06, + "P2_MSD": 0.0, + "P2_SIT01": -7.418240348817199e-06, + "P2_SIT02": -7.457826456660247e-06, + "P2_VT01": 1.247484205979928e-07, + "P2_VXT02": 6.53499778855353e-07, + "P2_VXT03": 5.32656056809399e-06, + "P2_VYT02": 9.483158480759724e-07, + "P2_VYT03": 2.128755351566922e-06, + "P3_FIT01": 2.2828575320599336e-05, + "P3_LCP01D": 1.3040993552131866e-05, + "P3_LCV01D": 3.781324885318626e-07, + "P3_LIT01": -7.824733758742217e-06, + "P3_PIT01": 3.210613447428708e-05, + "P4_HT_FD": 9.197840236384403e-05, + "P4_HT_LD": -2.4568845167931336e-08, + "P4_HT_PO": 3.997415489949367e-07, + "P4_LD": -6.253448074273654e-07, + "P4_ST_FD": 2.3472181460829935e-07, + "P4_ST_GOV": 2.494268873407866e-06, + "P4_ST_LD": 1.6692758818969547e-06, + "P4_ST_PO": 2.45129838870492e-06, + "P4_ST_PT01": 1.7637202837434092e-05, + "P4_ST_TT01": -1.9485876142550594e-05 + }, "max_rows": 50000 } \ No newline at end of file diff --git a/example/results/eval.json b/example/results/eval.json index 0ac9f8e..e3041f7 100644 --- a/example/results/eval.json +++ b/example/results/eval.json @@ -2,430 +2,430 @@ "rows": 1024, "continuous_summary": { "P1_B2004": { - "mean": 0.06806937500000017, - "std": 0.03529812909898432 + "mean": 0.06862171875, + "std": 0.035259175845515515 }, "P1_B2016": { - "mean": 1.4498257031250004, - "std": 0.5202270132660024 + "mean": 1.3826925292968744, + "std": 0.5121290702115694 }, "P1_B3004": { - "mean": 395.1642620507809, - "std": 18.951479974613218 + "mean": 396.2448581250001, + "std": 19.05011569291723 }, "P1_B3005": { - "mean": 1034.4064711914052, - "std": 106.23408116669583 + "mean": 1041.7244751367214, + "std": 103.92341871183335 }, "P1_B4002": { - "mean": 32.6127290039062, - "std": 0.799725330749142 + "mean": 32.59948117187498, + "std": 0.7957780591033785 }, "P1_B4005": { - "mean": 53.906249999999986, - "std": 49.87153584938548 + "mean": 49.414062500000014, + "std": 50.020996935373084 }, "P1_B400B": { - "mean": 1738.3725173242187, - "std": 1384.7459227832458 + "mean": 1712.5273975781247, + "std": 1390.5681557101923 }, "P1_B4022": { - "mean": 36.23618491210939, - "std": 1.7647217369622714 + "mean": 36.25454343749998, + "std": 1.7690532520721258 }, "P1_FCV02Z": { - "mean": 57.724750781249995, - "std": 48.54732220006803 + "mean": 58.78944615234368, + "std": 48.31409690193705 }, "P1_FCV03D": { - "mean": 56.75169301757812, - "std": 11.199559178422735 + "mean": 57.079870898437484, + "std": 11.206936616467804 }, "P1_FCV03Z": { - "mean": 55.566632607421845, - "std": 11.297853665660114 + "mean": 55.273929101562494, + "std": 11.23373431243193 }, "P1_FT01": { - "mean": 159.00522490234354, - "std": 108.28370690895133 + "mean": 154.35449926757812, + "std": 109.2483475479711 }, "P1_FT01Z": { - "mean": 760.7312570605471, - "std": 298.8224240896078 + "mean": 739.8981555468748, + "std": 301.436243374618 }, "P1_FT02": { - "mean": 1144.481825429689, - "std": 991.6954326635167 + "mean": 1269.5919779296892, + "std": 965.9342296967919 }, "P1_FT02Z": { - "mean": 1487.8695987890617, - "std": 1415.8144135499283 + "mean": 1535.3343196679693, + "std": 1413.824192981873 }, "P1_FT03": { - "mean": 256.8864012499994, - "std": 64.91894597175877 + "mean": 256.9996227343751, + "std": 64.93296018180462 }, "P1_FT03Z": { - "mean": 1015.7155878906256, - "std": 125.97944512388922 + "mean": 1017.4734681640617, + "std": 125.66707109475342 }, "P1_LCV01D": { - "mean": 13.136535000000006, - "std": 8.773087227459813 + "mean": 13.448744999999988, + "std": 8.717302095487849 }, "P1_LCV01Z": { - "mean": 9.427464931640635, - "std": 8.695702925059546 + "mean": 9.30471986328126, + "std": 8.66743535571038 }, "P1_LIT01": { - "mean": 400.05200666992226, - "std": 37.486039269525676 + "mean": 402.9207411230469, + "std": 36.97515497691232 }, "P1_PCV01D": { - "mean": 67.55343838867199, - "std": 37.17637545686828 + "mean": 68.44300272460937, + "std": 37.03521851450985 }, "P1_PCV01Z": { - "mean": 49.82481044921875, - "std": 34.92608216237906 + "mean": 48.90588702148442, + "std": 34.55315631185222 }, "P1_PCV02Z": { - "mean": 12.022767607421867, - "std": 0.018848054758257262 + "mean": 12.024047783203125, + "std": 0.018746495739474497 }, "P1_PIT01": { - "mean": 1.3068769824218769, - "std": 0.36610955142971885 + "mean": 1.3310731152343753, + "std": 0.3677450170593836 }, "P1_PIT02": { - "mean": 1.3398852539062494, - "std": 1.0838519802932796 + "mean": 1.2849893749999997, + "std": 1.0866531673857318 }, "P1_TIT01": { - "mean": 35.83285636718747, - "std": 1.074981874420601 + "mean": 35.77612789062499, + "std": 1.07625743741325 }, "P1_TIT02": { - "mean": 37.00902173828125, - "std": 2.607425943486915 + "mean": 36.935316699218745, + "std": 2.5870385657749857 }, "P2_24Vdc": { - "mean": 28.03205742187502, - "std": 0.01451220063756119 + "mean": 28.03155960937501, + "std": 0.014605826922890492 }, "P2_CO_rpm": { - "mean": 54089.67496093745, - "std": 83.65667448636566 + "mean": 54096.19540039063, + "std": 83.28739649017277 }, "P2_HILout": { - "mean": 708.0663378417961, - "std": 31.69619617693611 + "mean": 706.6639439062503, + "std": 31.61903880866903 }, "P2_MSD": { "mean": 763.19324, "std": 0.0 }, "P2_SIT01": { - "mean": 782.8947558593749, - "std": 18.15296528720419 + "mean": 782.7473828124998, + "std": 18.194800471814133 }, "P2_SIT02": { - "mean": 779.4267523730475, - "std": 18.52623485720013 + "mean": 779.1823742382821, + "std": 18.55564947581578 }, "P2_VT01": { - "mean": 11.942113281249998, - "std": 0.1212939383938054 + "mean": 11.938032617187508, + "std": 0.12190271288284543 }, "P2_VXT02": { - "mean": -3.3293065136718716, - "std": 1.0011199008755574 + "mean": -3.377015263671875, + "std": 0.9959783659170061 }, "P2_VXT03": { - "mean": -1.2753983105468742, - "std": 0.8392492416099594 + "mean": -1.1974353710937469, + "std": 0.8463535524317487 }, "P2_VYT02": { - "mean": 3.7446044042968776, - "std": 1.0195657446443727 + "mean": 3.7353859375000003, + "std": 1.0195706502676007 }, "P2_VYT03": { - "mean": 6.009656689453129, - "std": 0.9420566998843258 + "mean": 6.085932294921869, + "std": 0.931647679296692 }, "P3_FIT01": { - "mean": 2452.524746093747, - "std": 2700.720932270934 + "mean": 2537.117392578124, + "std": 2706.6282148304126 }, "P3_LCP01D": { - "mean": 5702.499999999998, - "std": 6810.097812156862 + "mean": 5947.764648437504, + "std": 6846.9763751739265 }, "P3_LCV01D": { - "mean": 9264.031249999993, - "std": 9029.897672702175 + "mean": 8222.312500000033, + "std": 9028.878178016928 }, "P3_LIT01": { - "mean": 12977.425781249998, - "std": 7169.404045518549 + "mean": 12808.792968750007, + "std": 7182.626250573974 }, "P3_PIT01": { - "mean": 1705.4522070312507, - "std": 1886.7900037587863 + "mean": 1619.4790917968737, + "std": 1876.6213805023137 }, "P4_HT_FD": { - "mean": 0.0004563964843749997, - "std": 0.009685606093521135 + "mean": 0.00012562500000000009, + "std": 0.009676935342083165 }, "P4_HT_LD": { - "mean": 31.12069875000001, - "std": 40.20561966234447 + "mean": 32.417695781249975, + "std": 40.51850858098518 }, "P4_HT_PO": { - "mean": 31.85621600585936, - "std": 40.3166843101482 + "mean": 31.12792625000001, + "std": 40.13557159439076 }, "P4_LD": { - "mean": 327.36375984375013, - "std": 128.53815520781228 + "mean": 330.763043642578, + "std": 129.48641524757082 }, "P4_ST_FD": { - "mean": -0.0010156933593749987, - "std": 0.007769561828982641 + "mean": -0.00027652343749999986, + "std": 0.007833558185460957 }, "P4_ST_GOV": { - "mean": 19994.500761718766, - "std": 6658.829447941799 + "mean": 20381.665029296888, + "std": 6603.011425275358 }, "P4_ST_LD": { - "mean": 345.45652514648435, - "std": 128.94975168344465 + "mean": 349.70025943359406, + "std": 129.20679188052912 }, "P4_ST_PO": { - "mean": 369.8679274511719, - "std": 125.46549145193221 + "mean": 364.372227851563, + "std": 125.4641046524738 }, "P4_ST_PT01": { - "mean": 10059.50247070312, - "std": 105.69814433833766 + "mean": 10049.883593749988, + "std": 106.29379548555222 }, "P4_ST_TT01": { - "mean": 27573.99121093751, - "std": 44.51705573286087 + "mean": 27574.791015625004, + "std": 44.68056797740681 } }, "continuous_error": { "P1_B2004": { - "mean_abs_err": 0.018421493200000097, - "std_abs_err": 0.01080563920029386 + "mean_abs_err": 0.01786914945000026, + "std_abs_err": 0.010766685946825057 }, "P1_B2016": { - "mean_abs_err": 0.0736642471249993, - "std_abs_err": 0.39073428761840495 + "mean_abs_err": 0.006531073296873302, + "std_abs_err": 0.38263634456397194 }, "P1_B3004": { - "mean_abs_err": 1.0218976398209065, - "std_abs_err": 8.788831968080327 + "mean_abs_err": 0.05869843439830902, + "std_abs_err": 8.88746768638434 }, "P1_B3005": { - "mean_abs_err": 2.965913222387826, - "std_abs_err": 35.37710457560583 + "mean_abs_err": 4.352090722928324, + "std_abs_err": 33.066442120743346 }, "P1_B4002": { - "mean_abs_err": 0.04785606310620949, - "std_abs_err": 0.04190401944830635 + "mean_abs_err": 0.034608231074983564, + "std_abs_err": 0.03795674780254299 }, "P1_B4005": { - "mean_abs_err": 12.075657572400488, - "std_abs_err": 8.07088269946751 + "mean_abs_err": 16.56784507240046, + "std_abs_err": 8.220343785455114 }, "P1_B400B": { - "mean_abs_err": 186.66663970037462, - "std_abs_err": 208.1013680383826 + "mean_abs_err": 212.51175944646866, + "std_abs_err": 213.92360096532911 }, "P1_B4022": { - "mean_abs_err": 0.05289575589062423, - "std_abs_err": 0.9426102303135625 + "mean_abs_err": 0.034537230500028215, + "std_abs_err": 0.9469417454234169 }, "P1_FCV02Z": { - "mean_abs_err": 35.98048966284996, - "std_abs_err": 9.428890222426269 + "mean_abs_err": 37.04518503394364, + "std_abs_err": 9.19566492429528 }, "P1_FCV03D": { - "mean_abs_err": 0.6095397238223228, - "std_abs_err": 3.3100517306961112 + "mean_abs_err": 0.2813618429629585, + "std_abs_err": 3.317429168741179 }, "P1_FCV03Z": { - "mean_abs_err": 2.484212588978174, - "std_abs_err": 3.251784759714397 + "mean_abs_err": 2.7769160948375244, + "std_abs_err": 3.1876654064862127 }, "P1_FT01": { - "mean_abs_err": 25.180926220853735, - "std_abs_err": 77.48253659012278 + "mean_abs_err": 29.831651855619157, + "std_abs_err": 78.44717722914254 }, "P1_FT01Z": { - "mean_abs_err": 91.14691801004938, - "std_abs_err": 207.54373754627602 + "mean_abs_err": 111.9800195237217, + "std_abs_err": 210.15755683128623 }, "P1_FT02": { - "mean_abs_err": 111.37539192471786, - "std_abs_err": 111.97910493006736 + "mean_abs_err": 13.73476057528228, + "std_abs_err": 86.21790196334257 }, "P1_FT02Z": { - "mean_abs_err": 437.1514767303497, - "std_abs_err": 239.14446041941687 + "mean_abs_err": 389.68675585144206, + "std_abs_err": 237.15423985136135 }, "P1_FT03": { - "mean_abs_err": 12.486457607806358, - "std_abs_err": 26.738787552109365 + "mean_abs_err": 12.373236123430615, + "std_abs_err": 26.752801762155208 }, "P1_FT03Z": { - "mean_abs_err": 21.650584339975467, - "std_abs_err": 55.248437379524944 + "mean_abs_err": 19.892704066539295, + "std_abs_err": 54.93606335038915 }, "P1_LCV01D": { - "mean_abs_err": 1.9076859514000422, - "std_abs_err": 5.437521685904054 + "mean_abs_err": 2.2198959514000247, + "std_abs_err": 5.38173655393209 }, "P1_LCV01Z": { - "mean_abs_err": 1.564145249959381, - "std_abs_err": 5.309370691286001 + "mean_abs_err": 1.6868903183187562, + "std_abs_err": 5.281103121936835 }, "P1_LIT01": { - "mean_abs_err": 3.167475558922888, - "std_abs_err": 26.907324509421557 + "mean_abs_err": 6.0362100120475475, + "std_abs_err": 26.3964402168082 }, "P1_PCV01D": { - "mean_abs_err": 13.752422204472133, - "std_abs_err": 17.560696020729427 + "mean_abs_err": 14.641986540409512, + "std_abs_err": 17.419539078371 }, "P1_PCV01Z": { - "mean_abs_err": 4.821829837980843, - "std_abs_err": 15.147327695076974 + "mean_abs_err": 5.740753265715178, + "std_abs_err": 14.774401844550134 }, "P1_PCV02Z": { - "mean_abs_err": 0.004994064621795857, - "std_abs_err": 0.014043256865097265 + "mean_abs_err": 0.006274240403053355, + "std_abs_err": 0.013941697846314497 }, "P1_PIT01": { - "mean_abs_err": 0.06240896637813065, - "std_abs_err": 0.28844805602440754 + "mean_abs_err": 0.03821283356563221, + "std_abs_err": 0.2900835216540723 }, "P1_PIT02": { - "mean_abs_err": 0.8952945413062471, - "std_abs_err": 0.6356196621367565 + "mean_abs_err": 0.8403986623999974, + "std_abs_err": 0.6384208492292087 }, "P1_TIT01": { - "mean_abs_err": 0.19029822718758993, - "std_abs_err": 0.4763140216677195 + "mean_abs_err": 0.13356975062511367, + "std_abs_err": 0.4775895846603687 }, "P1_TIT02": { - "mean_abs_err": 0.5609435076810172, - "std_abs_err": 1.41819182303481 + "mean_abs_err": 0.4872384686185143, + "std_abs_err": 1.3978044453228808 }, "P2_24Vdc": { - "mean_abs_err": 0.0040555204751200336, - "std_abs_err": 0.01130335813346338 + "mean_abs_err": 0.0035577079751085705, + "std_abs_err": 0.011396984418792677 }, "P2_CO_rpm": { - "mean_abs_err": 15.96938906252035, - "std_abs_err": 63.081196664858396 + "mean_abs_err": 9.448949609344709, + "std_abs_err": 62.71191866866543 }, "P2_HILout": { - "mean_abs_err": 3.99252890079606, - "std_abs_err": 23.5173427970275 + "mean_abs_err": 5.394922836341834, + "std_abs_err": 23.440185428760422 }, "P2_MSD": { "mean_abs_err": 0.0, "std_abs_err": 1.0 }, "P2_SIT01": { - "mean_abs_err": 4.117770859373536, - "std_abs_err": 14.258429511536933 + "mean_abs_err": 3.970397812498504, + "std_abs_err": 14.300264696146877 }, "P2_SIT02": { - "mean_abs_err": 0.6488588258494019, - "std_abs_err": 14.64375777834219 + "mean_abs_err": 0.40448069108401796, + "std_abs_err": 14.673172396957842 }, "P2_VT01": { - "mean_abs_err": 0.027163833049954178, - "std_abs_err": 0.05316402922710296 + "mean_abs_err": 0.023083168987463765, + "std_abs_err": 0.05377280371614296 }, "P2_VXT02": { - "mean_abs_err": 0.19748068032814592, - "std_abs_err": 0.5700783296998693 + "mean_abs_err": 0.1497719303281424, + "std_abs_err": 0.564936794741318 }, "P2_VXT03": { - "mean_abs_err": 0.27669218165311715, - "std_abs_err": 0.5703067220285616 + "mean_abs_err": 0.35465512110624453, + "std_abs_err": 0.577411032850351 }, "P2_VYT02": { - "mean_abs_err": 0.05150833330312432, - "std_abs_err": 0.5584749563236141 + "mean_abs_err": 0.06072680010000164, + "std_abs_err": 0.5584798619468421 }, "P2_VYT03": { - "mean_abs_err": 0.112035007546889, - "std_abs_err": 0.6360924060335711 + "mean_abs_err": 0.035759402078149094, + "std_abs_err": 0.6256833854459373 }, "P3_FIT01": { - "mean_abs_err": 1284.2718660937458, - "std_abs_err": 913.4221629567473 + "mean_abs_err": 1368.8645125781227, + "std_abs_err": 919.3294455162259 }, "P3_LCP01D": { - "mean_abs_err": 1027.0347600000096, - "std_abs_err": 1664.6883859755899 + "mean_abs_err": 1272.299408437515, + "std_abs_err": 1701.566948992654 }, "P3_LCV01D": { - "mean_abs_err": 1818.8225299999758, - "std_abs_err": 2244.294890937079 + "mean_abs_err": 777.1037800000158, + "std_abs_err": 2243.2753962518327 }, "P3_LIT01": { - "mean_abs_err": 751.5565337498538, - "std_abs_err": 3108.488603645804 + "mean_abs_err": 920.1893462498447, + "std_abs_err": 3121.7108087012293 }, "P3_PIT01": { - "mean_abs_err": 1036.4799720312503, - "std_abs_err": 718.6828773163836 + "mean_abs_err": 950.5068567968734, + "std_abs_err": 708.5142540599111 }, "P4_HT_FD": { - "mean_abs_err": 0.0005565222843750006, - "std_abs_err": 0.007653023712903543 + "mean_abs_err": 0.0002257508000000009, + "std_abs_err": 0.007644352961465573 }, "P4_HT_LD": { - "mean_abs_err": 4.298751250999523, - "std_abs_err": 6.993258493091233 + "mean_abs_err": 3.0017542197495573, + "std_abs_err": 7.306147411731949 }, "P4_HT_PO": { - "mean_abs_err": 3.5523539853408437, - "std_abs_err": 9.128858395633035 + "mean_abs_err": 4.280643741200194, + "std_abs_err": 8.947745679875602 }, "P4_LD": { - "mean_abs_err": 38.019614736648464, - "std_abs_err": 68.80153861876664 + "mean_abs_err": 34.6203309378206, + "std_abs_err": 69.74979865852518 }, "P4_ST_FD": { - "mean_abs_err": 0.0010091727593749989, - "std_abs_err": 0.006126683116239391 + "mean_abs_err": 0.0002700028374999999, + "std_abs_err": 0.006190679472717707 }, "P4_ST_GOV": { - "mean_abs_err": 2192.6878167188042, - "std_abs_err": 4918.229702128978 + "mean_abs_err": 2579.8520842969265, + "std_abs_err": 4862.411679462537 }, "P4_ST_LD": { - "mean_abs_err": 15.623932964484709, - "std_abs_err": 93.08341879444387 + "mean_abs_err": 19.867667251594412, + "std_abs_err": 93.34045899152835 }, "P4_ST_PO": { - "mean_abs_err": 39.759981301375205, - "std_abs_err": 93.0904787166755 + "mean_abs_err": 34.26428170176632, + "std_abs_err": 93.0890919172171 }, "P4_ST_PT01": { - "mean_abs_err": 11.822865702994022, - "std_abs_err": 83.2381815201914 + "mean_abs_err": 2.2039887498613098, + "std_abs_err": 83.83383266740597 }, "P4_ST_TT01": { - "mean_abs_err": 32.86885906264433, - "std_abs_err": 19.77111638263939 + "mean_abs_err": 32.06905437515161, + "std_abs_err": 19.934628627185322 } }, "discrete_invalid_counts": { @@ -455,5 +455,516 @@ "P3_LL": 0, "P4_HT_PS": 0, "P4_ST_PS": 0 + }, + "continuous_ks": { + "P1_B2004": 0.8106, + "P1_B2016": 0.5790015625, + "P1_B3004": 0.53125, + "P1_B3005": 0.4782921875, + "P1_B4002": 0.8105, + "P1_B4005": 0.79705, + "P1_B400B": 0.595653125, + "P1_B4022": 0.59375, + "P1_FCV02Z": 0.6123046875, + "P1_FCV03D": 0.50390625, + "P1_FCV03Z": 0.61328125, + "P1_FT01": 0.5380859375, + "P1_FT01Z": 0.5390625, + "P1_FT02": 0.6317859375, + "P1_FT02Z": 0.533153125, + "P1_FT03": 0.53125, + "P1_FT03Z": 0.587890625, + "P1_LCV01D": 0.5966796875, + "P1_LCV01Z": 0.611328125, + "P1_LIT01": 0.6025390625, + "P1_PCV01D": 0.5791015625, + "P1_PCV01Z": 0.685546875, + "P1_PCV02Z": 0.568359375, + "P1_PIT01": 0.543871875, + "P1_PIT02": 0.5101953125, + "P1_TIT01": 0.501953125, + "P1_TIT02": 0.6396484375, + "P2_24Vdc": 0.6011625, + "P2_CO_rpm": 0.532503125, + "P2_HILout": 0.524140625, + "P2_MSD": 1.0, + "P2_SIT01": 0.6023390625, + "P2_SIT02": 0.505659375, + "P2_VT01": 0.5654296875, + "P2_VXT02": 0.5615234375, + "P2_VXT03": 0.5126953125, + "P2_VYT02": 0.52734375, + "P2_VYT03": 0.583984375, + "P3_FIT01": 0.52734375, + "P3_LCP01D": 0.568359375, + "P3_LCV01D": 0.529296875, + "P3_LIT01": 0.533203125, + "P3_PIT01": 0.5654296875, + "P4_HT_FD": 0.5009390625, + "P4_HT_LD": 0.609375, + "P4_HT_PO": 0.625, + "P4_LD": 0.6279296875, + "P4_ST_FD": 0.5097921875, + "P4_ST_GOV": 0.577675, + "P4_ST_LD": 0.5361328125, + "P4_ST_PO": 0.51953125, + "P4_ST_PT01": 0.5004765625, + "P4_ST_TT01": 0.595703125 + }, + "continuous_quantile_diff": { + "P1_B2004": { + "q05_diff": 0.0, + "q25_diff": 0.0707, + "q50_diff": 0.0, + "q75_diff": 0.0, + "q95_diff": 0.0 + }, + "P1_B2016": { + "q05_diff": 0.20477000000000012, + "q25_diff": 0.35475999999999996, + "q50_diff": 0.4335500000000001, + "q75_diff": 0.5400499999999999, + "q95_diff": 0.4267000000000001 + }, + "P1_B3004": { + "q05_diff": 14.80313000000001, + "q25_diff": 19.27872000000002, + "q50_diff": 19.27872000000002, + "q75_diff": 18.877499999999998, + "q95_diff": 18.877499999999998 + }, + "P1_B3005": { + "q05_diff": 107.27929999999992, + "q25_diff": 107.27929999999992, + "q50_diff": 113.12176000000011, + "q75_diff": 113.12176000000011, + "q95_diff": 0.0 + }, + "P1_B4002": { + "q05_diff": 0.0, + "q25_diff": 1.6555000000000035, + "q50_diff": 1.6555000000000035, + "q75_diff": 0.0, + "q95_diff": 0.0 + }, + "P1_B4005": { + "q05_diff": 0.0, + "q25_diff": 100.0, + "q50_diff": 100.0, + "q75_diff": 0.0, + "q95_diff": 0.0 + }, + "P1_B400B": { + "q05_diff": 8.937850000000001, + "q25_diff": 2803.04775, + "q50_diff": 22.87377000000015, + "q75_diff": 18.10082999999986, + "q95_diff": 12.064449999999852 + }, + "P1_B4022": { + "q05_diff": 0.741570000000003, + "q25_diff": 2.147590000000001, + "q50_diff": 2.5206700000000026, + "q75_diff": 0.883100000000006, + "q95_diff": 0.5329200000000043 + }, + "P1_FCV02Z": { + "q05_diff": 0.015249999999999986, + "q25_diff": 0.015249999999999986, + "q50_diff": 99.09821000000001, + "q75_diff": 99.09821000000001, + "q95_diff": 0.030520000000009873 + }, + "P1_FCV03D": { + "q05_diff": 4.971059999999994, + "q25_diff": 5.407429999999998, + "q50_diff": 5.7296499999999995, + "q75_diff": 16.195880000000002, + "q95_diff": 1.1158599999999979 + }, + "P1_FCV03Z": { + "q05_diff": 5.249020000000002, + "q25_diff": 5.607600000000005, + "q50_diff": 5.729670000000006, + "q75_diff": 16.738889999999998, + "q95_diff": 1.1749300000000034 + }, + "P1_FT01": { + "q05_diff": 120.43077000000001, + "q25_diff": 130.32031, + "q50_diff": 83.73258999999999, + "q75_diff": 78.01056, + "q95_diff": 34.52304000000001 + }, + "P1_FT01Z": { + "q05_diff": 387.17252, + "q25_diff": 407.15109, + "q50_diff": 187.60754000000009, + "q75_diff": 174.5613400000001, + "q95_diff": 75.40979000000004 + }, + "P1_FT02": { + "q05_diff": 1.7166099999999993, + "q25_diff": 1961.32641, + "q50_diff": 31.47144000000003, + "q75_diff": 24.98621000000003, + "q95_diff": 16.78478999999993 + }, + "P1_FT02Z": { + "q05_diff": 8.937850000000001, + "q25_diff": 2803.04775, + "q50_diff": 22.87377000000015, + "q75_diff": 18.10082999999986, + "q95_diff": 12.064449999999852 + }, + "P1_FT03": { + "q05_diff": 57.21861999999999, + "q25_diff": 58.36301999999998, + "q50_diff": 70.76033999999999, + "q75_diff": 69.23453999999998, + "q95_diff": 3.8145700000000033 + }, + "P1_FT03Z": { + "q05_diff": 130.45844, + "q25_diff": 133.06768999999997, + "q50_diff": 120.01207999999997, + "q75_diff": 116.53325999999993, + "q95_diff": 6.377929999999878 + }, + "P1_LCV01D": { + "q05_diff": 4.2898, + "q25_diff": 5.0423599999999995, + "q50_diff": 11.84572, + "q75_diff": 9.60605, + "q95_diff": 4.67375 + }, + "P1_LCV01Z": { + "q05_diff": 4.40978, + "q25_diff": 5.241390000000001, + "q50_diff": 5.95092, + "q75_diff": 9.460459999999998, + "q95_diff": 4.226689999999998 + }, + "P1_LIT01": { + "q05_diff": 34.6062, + "q25_diff": 38.28658999999999, + "q50_diff": 35.42405000000002, + "q75_diff": 33.32828000000001, + "q95_diff": 29.698980000000006 + }, + "P1_PCV01D": { + "q05_diff": 10.95326, + "q25_diff": 15.403260000000003, + "q50_diff": 56.51593, + "q75_diff": 52.88774, + "q95_diff": 46.32521 + }, + "P1_PCV01Z": { + "q05_diff": 10.925290000000004, + "q25_diff": 15.548700000000004, + "q50_diff": 18.63098, + "q75_diff": 51.9104, + "q95_diff": 45.50171 + }, + "P1_PCV02Z": { + "q05_diff": 0.007629999999998915, + "q25_diff": 0.007629999999998915, + "q50_diff": 0.021849999999998815, + "q75_diff": 0.0228900000000003, + "q95_diff": 0.0228900000000003 + }, + "P1_PIT01": { + "q05_diff": 0.27142999999999995, + "q25_diff": 0.3475299999999999, + "q50_diff": 0.3565400000000001, + "q75_diff": 0.32497, + "q95_diff": 0.27395000000000014 + }, + "P1_PIT02": { + "q05_diff": 0.03356999999999999, + "q25_diff": 0.10451999999999997, + "q50_diff": 2.0233700000000003, + "q75_diff": 2.06985, + "q95_diff": 2.06909 + }, + "P1_TIT01": { + "q05_diff": 0.07629999999999626, + "q25_diff": 0.4272399999999976, + "q50_diff": 0.991819999999997, + "q75_diff": 0.5187900000000027, + "q95_diff": 0.09155000000000513 + }, + "P1_TIT02": { + "q05_diff": 0.09155000000000513, + "q25_diff": 0.4577600000000004, + "q50_diff": 1.2207000000000008, + "q75_diff": 3.2806399999999982, + "q95_diff": 1.2207099999999969 + }, + "P2_24Vdc": { + "q05_diff": 0.00946000000000069, + "q25_diff": 0.012100000000000222, + "q50_diff": 0.014739999999999753, + "q75_diff": 0.013829999999998677, + "q95_diff": 0.010680000000000689 + }, + "P2_CO_rpm": { + "q05_diff": 68.91000000000349, + "q25_diff": 88.13999999999942, + "q50_diff": 67.0, + "q75_diff": 54.0, + "q95_diff": 37.0 + }, + "P2_HILout": { + "q05_diff": 21.350099999999998, + "q25_diff": 31.207269999999994, + "q50_diff": 35.31494000000009, + "q75_diff": 21.63695999999993, + "q95_diff": 14.538569999999936 + }, + "P2_MSD": { + "q05_diff": 0.0, + "q25_diff": 0.0, + "q50_diff": 0.0, + "q75_diff": 0.0, + "q95_diff": 0.0 + }, + "P2_SIT01": { + "q05_diff": 12.580000000000041, + "q25_diff": 16.710000000000036, + "q50_diff": 17.850000000000023, + "q75_diff": 16.899999999999977, + "q95_diff": 13.25 + }, + "P2_SIT02": { + "q05_diff": 12.788270000000011, + "q25_diff": 16.65368000000001, + "q50_diff": 14.857610000000022, + "q75_diff": 16.707580000000007, + "q95_diff": 13.430229999999938 + }, + "P2_VT01": { + "q05_diff": 0.015200000000000102, + "q25_diff": 0.046990000000000975, + "q50_diff": 0.13512000000000057, + "q75_diff": 0.07174000000000014, + "q95_diff": 0.03888000000000069 + }, + "P2_VXT02": { + "q05_diff": 0.08729999999999993, + "q25_diff": 0.2607000000000004, + "q50_diff": 0.6693000000000002, + "q75_diff": 0.9367000000000001, + "q95_diff": 0.7728999999999999 + }, + "P2_VXT03": { + "q05_diff": 0.07350000000000012, + "q25_diff": 0.18210000000000015, + "q50_diff": 0.4036000000000002, + "q75_diff": 1.0967, + "q95_diff": 0.9957 + }, + "P2_VYT02": { + "q05_diff": 0.3511000000000002, + "q25_diff": 0.5322, + "q50_diff": 0.9708999999999999, + "q75_diff": 0.6384999999999996, + "q95_diff": 0.4569000000000001 + }, + "P2_VYT03": { + "q05_diff": 0.6910999999999996, + "q25_diff": 0.8129999999999997, + "q50_diff": 0.8028000000000004, + "q75_diff": 0.5364000000000004, + "q95_diff": 0.41800000000000015 + }, + "P3_FIT01": { + "q05_diff": 2.0, + "q25_diff": 4.0, + "q50_diff": 76.0, + "q75_diff": 2735.0, + "q95_diff": 382.0 + }, + "P3_LCP01D": { + "q05_diff": 8.0, + "q25_diff": 56.0, + "q50_diff": 1760.0, + "q75_diff": 4112.0, + "q95_diff": 216.0 + }, + "P3_LCV01D": { + "q05_diff": 16.0, + "q25_diff": 336.0, + "q50_diff": 9488.0, + "q75_diff": 3376.0, + "q95_diff": 1584.0 + }, + "P3_LIT01": { + "q05_diff": 1310.0, + "q25_diff": 4632.0, + "q50_diff": 6346.0, + "q75_diff": 3409.0, + "q95_diff": 473.0 + }, + "P3_PIT01": { + "q05_diff": 2.0, + "q25_diff": 3.0, + "q50_diff": 4.0, + "q75_diff": 2855.0, + "q95_diff": 259.0 + }, + "P4_HT_FD": { + "q05_diff": 0.00863, + "q25_diff": 0.00963, + "q50_diff": 0.007980000000000001, + "q75_diff": 0.00971, + "q95_diff": 0.00881 + }, + "P4_HT_LD": { + "q05_diff": 0.0, + "q25_diff": 0.0, + "q50_diff": 54.74537, + "q75_diff": 14.424189999999996, + "q95_diff": 6.669560000000004 + }, + "P4_HT_PO": { + "q05_diff": 0.0, + "q25_diff": 1.35638, + "q50_diff": 43.402800000000006, + "q75_diff": 15.426150000000007, + "q95_diff": 7.179570000000012 + }, + "P4_LD": { + "q05_diff": 38.17633000000001, + "q25_diff": 89.59051, + "q50_diff": 137.02618, + "q75_diff": 84.97899999999998, + "q95_diff": 39.27954 + }, + "P4_ST_FD": { + "q05_diff": 0.00547, + "q25_diff": 0.00697, + "q50_diff": 0.00664, + "q75_diff": 0.00685, + "q95_diff": 0.0054800000000000005 + }, + "P4_ST_GOV": { + "q05_diff": 2299.0, + "q25_diff": 4178.0, + "q50_diff": 7730.490000000002, + "q75_diff": 7454.919999999998, + "q95_diff": 5812.0 + }, + "P4_ST_LD": { + "q05_diff": 39.333740000000034, + "q25_diff": 76.64203000000003, + "q50_diff": 100.76677999999998, + "q75_diff": 137.83997, + "q95_diff": 105.36016999999998 + }, + "P4_ST_PO": { + "q05_diff": 43.02301, + "q25_diff": 78.10687000000001, + "q50_diff": 136.67437999999999, + "q75_diff": 139.8797, + "q95_diff": 107.81970000000001 + }, + "P4_ST_PT01": { + "q05_diff": 83.0, + "q25_diff": 89.54999999999927, + "q50_diff": 61.399999999999636, + "q75_diff": 104.46999999999935, + "q95_diff": 76.94000000000051 + }, + "P4_ST_TT01": { + "q05_diff": 14.0, + "q25_diff": 48.0, + "q50_diff": 88.0, + "q75_diff": 2.0, + "q95_diff": 2.0 + } + }, + "continuous_lag1_diff": { + "P1_B2004": 0.9718697145204541, + "P1_B2016": 1.0096989619256902, + "P1_B3004": 0.987235023697155, + "P1_B3005": 1.0293153257905971, + "P1_B4002": 1.0101720557598692, + "P1_B4005": 1.0069799798786847, + "P1_B400B": 1.0100831396536611, + "P1_B4022": 1.005260911857099, + "P1_FCV02Z": 1.0208150011699488, + "P1_FCV03D": 0.9677639878142119, + "P1_FCV03Z": 1.0448432960637175, + "P1_FT01": 1.0174834364370597, + "P1_FT01Z": 0.9623633857501661, + "P1_FT02": 0.9492143748470167, + "P1_FT02Z": 1.0248788433686373, + "P1_FT03": 0.9967197043976962, + "P1_FT03Z": 1.0042718034976392, + "P1_LCV01D": 0.9767924561102216, + "P1_LCV01Z": 1.0238429714137112, + "P1_LIT01": 0.994524637441244, + "P1_PCV01D": 0.9731502416003561, + "P1_PCV01Z": 0.9987281311850285, + "P1_PCV02Z": 0.5767325423170566, + "P1_PIT01": 1.0052067307895398, + "P1_PIT02": 1.07011185942615, + "P1_TIT01": 1.0555120205346828, + "P1_TIT02": 0.9917823846962297, + "P2_24Vdc": 0.005526132896643308, + "P2_CO_rpm": 0.40287161666960586, + "P2_HILout": 0.25625640743099676, + "P2_MSD": 0.0, + "P2_SIT01": 0.7010364996672233, + "P2_SIT02": 0.7090423064483174, + "P2_VT01": 0.8506494578213937, + "P2_VXT02": 0.8094431207350834, + "P2_VXT03": 0.8417667674176789, + "P2_VYT02": 0.8415060810530584, + "P2_VYT03": 0.8550842087621501, + "P3_FIT01": 0.9816722571345234, + "P3_LCP01D": 0.9961743411760093, + "P3_LCV01D": 0.9663083457613558, + "P3_LIT01": 1.0191146983629609, + "P3_PIT01": 0.9827441734544354, + "P4_HT_FD": 0.26359661930055545, + "P4_HT_LD": 1.0050543113208599, + "P4_HT_PO": 1.0522741078670352, + "P4_LD": 0.9692264478458912, + "P4_ST_FD": 0.37863299269548845, + "P4_ST_GOV": 1.0392419101031738, + "P4_ST_LD": 0.9837987258388408, + "P4_ST_PO": 1.0089254379868162, + "P4_ST_PT01": 1.0237355221468971, + "P4_ST_TT01": 0.9905047944110966 + }, + "discrete_jsd": { + "P1_FCV01D": 0.10738201625126076, + "P1_FCV01Z": 0.24100377687195113, + "P1_FCV02D": 0.059498960127610634, + "P1_PCV02D": 0.0, + "P1_PP01AD": 0.0, + "P1_PP01AR": 0.0, + "P1_PP01BD": 0.0, + "P1_PP01BR": 0.0, + "P1_PP02D": 0.0, + "P1_PP02R": 0.0, + "P1_STSP": 0.0, + "P2_ASD": 0.0, + "P2_AutoGO": 0.0, + "P2_Emerg": 0.0, + "P2_ManualGO": 0.0, + "P2_OnOff": 0.0, + "P2_RTR": 0.0, + "P2_TripEx": 0.0, + "P2_VTR01": 0.0, + "P2_VTR02": 0.0, + "P2_VTR03": 0.0, + "P2_VTR04": 0.0, + "P3_LH": 0.0, + "P3_LL": 0.0, + "P4_HT_PS": 0.0, + "P4_ST_PS": 0.0 } } \ No newline at end of file diff --git a/example/run_pipeline.py b/example/run_pipeline.py index 0f1fd69..c93974d 100644 --- a/example/run_pipeline.py +++ b/example/run_pipeline.py @@ -48,6 +48,8 @@ def main(): seq_len = cfg.get("sample_seq_len", cfg.get("seq_len", 64)) batch_size = cfg.get("sample_batch_size", cfg.get("batch_size", 2)) clip_k = cfg.get("clip_k", 5.0) + data_glob = cfg.get("data_glob", "") + data_path = cfg.get("data_path", "") run([sys.executable, str(base_dir / "prepare_data.py")]) run([sys.executable, str(base_dir / "train.py"), "--config", args.config, "--device", args.device]) run( @@ -70,7 +72,11 @@ def main(): "--use-ema", ] ) - run([sys.executable, str(base_dir / "evaluate_generated.py")]) + ref = data_glob if data_glob else data_path + if ref: + run([sys.executable, str(base_dir / "evaluate_generated.py"), "--reference", str(ref)]) + else: + run([sys.executable, str(base_dir / "evaluate_generated.py")]) run([sys.executable, str(base_dir / "plot_loss.py")]) diff --git a/example/sample.py b/example/sample.py index 14a5863..1b5f85d 100755 --- a/example/sample.py +++ b/example/sample.py @@ -47,6 +47,13 @@ def main(): cond_dim = int(cfg.get("cond_dim", 32)) use_tanh_eps = bool(cfg.get("use_tanh_eps", False)) eps_scale = float(cfg.get("eps_scale", 1.0)) + model_time_dim = int(cfg.get("model_time_dim", 64)) + model_hidden_dim = int(cfg.get("model_hidden_dim", 256)) + model_num_layers = int(cfg.get("model_num_layers", 1)) + model_dropout = float(cfg.get("model_dropout", 0.0)) + model_ff_mult = int(cfg.get("model_ff_mult", 2)) + model_pos_dim = int(cfg.get("model_pos_dim", 64)) + model_use_pos = bool(cfg.get("model_use_pos_embed", True)) split = load_split(str(SPLIT_PATH)) time_col = split.get("time_column", "time") @@ -67,6 +74,13 @@ def main(): model = HybridDiffusionModel( cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes, + time_dim=model_time_dim, + hidden_dim=model_hidden_dim, + num_layers=model_num_layers, + dropout=model_dropout, + ff_mult=model_ff_mult, + pos_dim=model_pos_dim, + use_pos_embed=model_use_pos, cond_vocab_size=cond_vocab_size, cond_dim=cond_dim, use_tanh_eps=use_tanh_eps, diff --git a/example/train.py b/example/train.py index 13699c7..467eb76 100755 --- a/example/train.py +++ b/example/train.py @@ -49,8 +49,17 @@ DEFAULTS = { "use_condition": True, "condition_type": "file_id", "cond_dim": 32, - "use_tanh_eps": True, + "use_tanh_eps": False, "eps_scale": 1.0, + "model_time_dim": 128, + "model_hidden_dim": 512, + "model_num_layers": 2, + "model_dropout": 0.1, + "model_ff_mult": 2, + "model_pos_dim": 64, + "model_use_pos_embed": True, + "disc_mask_scale": 0.9, + "shuffle_buffer": 256, } @@ -144,6 +153,7 @@ def main(): stats = load_json(config["stats_path"]) mean = stats["mean"] std = stats["std"] + transforms = stats.get("transform", {}) vocab = load_json(config["vocab_path"])["vocab"] vocab_sizes = [len(vocab[c]) for c in disc_cols] @@ -164,6 +174,13 @@ def main(): model = HybridDiffusionModel( cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes, + time_dim=int(config.get("model_time_dim", 64)), + hidden_dim=int(config.get("model_hidden_dim", 256)), + num_layers=int(config.get("model_num_layers", 1)), + dropout=float(config.get("model_dropout", 0.0)), + ff_mult=int(config.get("model_ff_mult", 2)), + pos_dim=int(config.get("model_pos_dim", 64)), + use_pos_embed=bool(config.get("model_use_pos_embed", True)), cond_vocab_size=cond_vocab_size, cond_dim=int(config.get("cond_dim", 32)), use_tanh_eps=bool(config.get("use_tanh_eps", False)), @@ -198,6 +215,8 @@ def main(): seq_len=int(config["seq_len"]), max_batches=int(config["max_batches"]), return_file_id=use_condition, + transforms=transforms, + shuffle_buffer=int(config.get("shuffle_buffer", 0)), ) ): if use_condition: @@ -215,7 +234,13 @@ def main(): x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod) mask_tokens = torch.tensor(vocab_sizes, device=device) - x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, int(config["timesteps"])) + x_disc_t, mask = q_sample_discrete( + x_disc, + t, + mask_tokens, + int(config["timesteps"]), + mask_scale=float(config.get("disc_mask_scale", 1.0)), + ) eps_pred, logits = model(x_cont_t, x_disc_t, t, cond)