From ff12324560f1dd544569026c06dd8b78e884f839 Mon Sep 17 00:00:00 2001
From: MingzheYang
Date: Fri, 23 Jan 2026 15:06:52 +0800
Subject: [PATCH] Address the weak temporal correlation of continuous features
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 example/README.md               |   1 +
 example/config.json             |  17 +-
 example/data_utils.py           | 216 +++++++++++--
 example/evaluate_generated.py   | 147 ++++++++-
 example/export_samples.py      |  15 +-
 example/hybrid_diffusion.py     |  43 ++-
 example/prepare_data.py         |  23 +-
 example/results/cont_stats.json | 224 ++++++++++++-
 example/results/eval.json       | 543 +++++++++++++++++++++++++++++++-
 example/run_pipeline.py         |   8 +-
 example/sample.py               |  14 +
 example/train.py                |  29 +-
 12 files changed, 1212 insertions(+), 68 deletions(-)

diff --git a/example/README.md b/example/README.md
index 32ffe81..b88330f 100644
--- a/example/README.md
+++ b/example/README.md
@@ -67,6 +67,7 @@ python example/run_pipeline.py --device auto
 - Optional conditioning by file id (`train*.csv.gz`) is enabled by default for multi-file training.
 - Continuous head can be bounded with `tanh` via `use_tanh_eps` in config.
 - Export now clamps continuous features to training min/max and preserves integer/decimal precision.
+- Continuous features may be log1p-transformed automatically for heavy-tailed columns (see cont_stats.json).
 - `` tokens are replaced by the most frequent token for each discrete column at export.
 - The script only samples the first 5000 rows to stay fast.
 - `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it.
diff --git a/example/config.json b/example/config.json
index d5feaf6..58baa92 100644
--- a/example/config.json
+++ b/example/config.json
@@ -6,11 +6,11 @@
   "vocab_path": "./results/disc_vocab.json",
   "out_dir": "./results",
   "device": "auto",
-  "timesteps": 400,
+  "timesteps": 600,
   "batch_size": 128,
   "seq_len": 128,
-  "epochs": 8,
-  "max_batches": 3000,
+  "epochs": 10,
+  "max_batches": 4000,
   "lambda": 0.5,
   "lr": 0.0005,
   "seed": 1337,
@@ -23,8 +23,17 @@
   "use_condition": true,
   "condition_type": "file_id",
   "cond_dim": 32,
-  "use_tanh_eps": true,
+  "use_tanh_eps": false,
   "eps_scale": 1.0,
+  "model_time_dim": 128,
+  "model_hidden_dim": 512,
+  "model_num_layers": 2,
+  "model_dropout": 0.1,
+  "model_ff_mult": 2,
+  "model_pos_dim": 64,
+  "model_use_pos_embed": true,
+  "disc_mask_scale": 0.9,
+  "shuffle_buffer": 256,
   "sample_batch_size": 8,
   "sample_seq_len": 128
 }
diff --git a/example/data_utils.py b/example/data_utils.py
index a38b221..aa36195 100755
--- a/example/data_utils.py
+++ b/example/data_utils.py
@@ -4,6 +4,8 @@
 import csv
 import gzip
 import json
+import math
+import random
 
 from typing import Dict, Iterable, List, Optional, Tuple, Union
 
@@ -23,62 +25,173 @@ def iter_rows(path_or_paths: Union[str, List[str]]) -> Iterable[Dict[str, str]]:
             yield row
 
 
-def compute_cont_stats(
+def _stream_basic_stats(
     path: Union[str, List[str]],
     cont_cols: List[str],
     max_rows: Optional[int] = None,
-) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, bool], Dict[str, int]]:
-    """Streaming mean/std (Welford) + min/max + int/precision metadata."""
-    count = 0
+):
+    """Streaming stats with mean/M2/M3 + min/max + int/precision metadata."""
+    count = {c: 0 for c in cont_cols}
    mean = {c: 0.0 for c in cont_cols}
     m2 = {c: 0.0 for c in cont_cols}
+    m3 = {c: 0.0 for c in cont_cols}
     vmin = {c: 
float("inf") for c in cont_cols} vmax = {c: float("-inf") for c in cont_cols} int_like = {c: True for c in cont_cols} max_decimals = {c: 0 for c in cont_cols} + all_pos = {c: True for c in cont_cols} for i, row in enumerate(iter_rows(path)): - count += 1 for c in cont_cols: raw = row[c] if raw is None or raw == "": continue x = float(raw) - delta = x - mean[c] - mean[c] += delta / count - delta2 = x - mean[c] - m2[c] += delta * delta2 + if x <= 0: + all_pos[c] = False if x < vmin[c]: vmin[c] = x if x > vmax[c]: vmax[c] = x if int_like[c] and abs(x - round(x)) > 1e-9: int_like[c] = False - # track decimal places from raw string if possible - if "e" in raw or "E" in raw: - # scientific notation, skip precision inference - continue - if "." in raw: + if "e" not in raw and "E" not in raw and "." in raw: dec = raw.split(".", 1)[1].rstrip("0") if len(dec) > max_decimals[c]: max_decimals[c] = len(dec) + + n = count[c] + 1 + delta = x - mean[c] + delta_n = delta / n + term1 = delta * delta_n * (n - 1) + m3[c] += term1 * delta_n * (n - 2) - 3 * delta_n * m2[c] + m2[c] += term1 + mean[c] += delta_n + count[c] = n + if max_rows is not None and i + 1 >= max_rows: break + # finalize std/skew std = {} + skew = {} for c in cont_cols: - if count > 1: - var = m2[c] / (count - 1) + n = count[c] + if n > 1: + var = m2[c] / (n - 1) else: var = 0.0 std[c] = var ** 0.5 if var > 0 else 1.0 - # replace infs if column had no valid values + if n > 2 and m2[c] > 0: + skew[c] = (math.sqrt(n) * (m3[c] / n)) / (m2[c] ** 1.5) + else: + skew[c] = 0.0 + for c in cont_cols: if vmin[c] == float("inf"): vmin[c] = 0.0 if vmax[c] == float("-inf"): vmax[c] = 0.0 - return mean, std, vmin, vmax, int_like, max_decimals + + return { + "count": count, + "mean": mean, + "std": std, + "m2": m2, + "m3": m3, + "min": vmin, + "max": vmax, + "int_like": int_like, + "max_decimals": max_decimals, + "skew": skew, + "all_pos": all_pos, + } + + +def choose_cont_transforms( + path: Union[str, List[str]], + cont_cols: List[str], + max_rows: Optional[int] = None, + skew_threshold: float = 1.5, + range_ratio_threshold: float = 1e3, +): + """Pick per-column transform (currently log1p or none) based on skew/range.""" + stats = _stream_basic_stats(path, cont_cols, max_rows=max_rows) + transforms = {} + for c in cont_cols: + if not stats["all_pos"][c]: + transforms[c] = "none" + continue + skew = abs(stats["skew"][c]) + vmin = stats["min"][c] + vmax = stats["max"][c] + ratio = (vmax / vmin) if vmin > 0 else 0.0 + if skew >= skew_threshold or ratio >= range_ratio_threshold: + transforms[c] = "log1p" + else: + transforms[c] = "none" + return transforms, stats + + +def compute_cont_stats( + path: Union[str, List[str]], + cont_cols: List[str], + max_rows: Optional[int] = None, + transforms: Optional[Dict[str, str]] = None, +): + """Compute stats on (optionally transformed) values. 
Returns raw + transformed stats.""" + # First pass (raw) for metadata and raw mean/std + raw = _stream_basic_stats(path, cont_cols, max_rows=max_rows) + + # Optional transform selection + if transforms is None: + transforms = {c: "none" for c in cont_cols} + + # Second pass for transformed mean/std + count = {c: 0 for c in cont_cols} + mean = {c: 0.0 for c in cont_cols} + m2 = {c: 0.0 for c in cont_cols} + for i, row in enumerate(iter_rows(path)): + for c in cont_cols: + raw_val = row[c] + if raw_val is None or raw_val == "": + continue + x = float(raw_val) + if transforms.get(c) == "log1p": + if x < 0: + x = 0.0 + x = math.log1p(x) + n = count[c] + 1 + delta = x - mean[c] + mean[c] += delta / n + delta2 = x - mean[c] + m2[c] += delta * delta2 + count[c] = n + if max_rows is not None and i + 1 >= max_rows: + break + + std = {} + for c in cont_cols: + if count[c] > 1: + var = m2[c] / (count[c] - 1) + else: + var = 0.0 + std[c] = var ** 0.5 if var > 0 else 1.0 + + return { + "mean": mean, + "std": std, + "raw_mean": raw["mean"], + "raw_std": raw["std"], + "min": raw["min"], + "max": raw["max"], + "int_like": raw["int_like"], + "max_decimals": raw["max_decimals"], + "transform": transforms, + "skew": raw["skew"], + "all_pos": raw["all_pos"], + "max_rows": max_rows, + } def build_vocab( @@ -130,8 +243,19 @@ def build_disc_stats( return vocab, top_token -def normalize_cont(x, cont_cols: List[str], mean: Dict[str, float], std: Dict[str, float]): +def normalize_cont( + x, + cont_cols: List[str], + mean: Dict[str, float], + std: Dict[str, float], + transforms: Optional[Dict[str, str]] = None, +): import torch + + if transforms: + for i, c in enumerate(cont_cols): + if transforms.get(c) == "log1p": + x[:, :, i] = torch.log1p(torch.clamp(x[:, :, i], min=0)) mean_t = torch.tensor([mean[c] for c in cont_cols], dtype=x.dtype, device=x.device) std_t = torch.tensor([std[c] for c in cont_cols], dtype=x.dtype, device=x.device) return (x - mean_t) / std_t @@ -148,19 +272,34 @@ def windowed_batches( seq_len: int, max_batches: Optional[int] = None, return_file_id: bool = False, + transforms: Optional[Dict[str, str]] = None, + shuffle_buffer: int = 0, ): import torch batch_cont = [] batch_disc = [] batch_file = [] + buffer = [] seq_cont = [] seq_disc = [] - def flush_seq(): - nonlocal seq_cont, seq_disc, batch_cont, batch_disc + def flush_seq(file_id: int): + nonlocal seq_cont, seq_disc, batch_cont, batch_disc, batch_file if len(seq_cont) == seq_len: - batch_cont.append(seq_cont) - batch_disc.append(seq_disc) + if shuffle_buffer and shuffle_buffer > 0: + buffer.append((list(seq_cont), list(seq_disc), file_id)) + if len(buffer) >= shuffle_buffer: + idx = random.randrange(len(buffer)) + seq_c, seq_d, seq_f = buffer.pop(idx) + batch_cont.append(seq_c) + batch_disc.append(seq_d) + if return_file_id: + batch_file.append(seq_f) + else: + batch_cont.append(seq_cont) + batch_disc.append(seq_disc) + if return_file_id: + batch_file.append(file_id) seq_cont = [] seq_disc = [] @@ -173,13 +312,11 @@ def windowed_batches( seq_cont.append(cont_row) seq_disc.append(disc_row) if len(seq_cont) == seq_len: - flush_seq() - if return_file_id: - batch_file.append(file_id) + flush_seq(file_id) if len(batch_cont) == batch_size: x_cont = torch.tensor(batch_cont, dtype=torch.float32) x_disc = torch.tensor(batch_disc, dtype=torch.long) - x_cont = normalize_cont(x_cont, cont_cols, mean, std) + x_cont = normalize_cont(x_cont, cont_cols, mean, std, transforms=transforms) if return_file_id: x_file = torch.tensor(batch_file, 
dtype=torch.long) yield x_cont, x_disc, x_file @@ -195,4 +332,29 @@ def windowed_batches( seq_cont = [] seq_disc = [] + # Flush any remaining buffered sequences + if shuffle_buffer and buffer: + random.shuffle(buffer) + for seq_c, seq_d, seq_f in buffer: + batch_cont.append(seq_c) + batch_disc.append(seq_d) + if return_file_id: + batch_file.append(seq_f) + if len(batch_cont) == batch_size: + import torch + x_cont = torch.tensor(batch_cont, dtype=torch.float32) + x_disc = torch.tensor(batch_disc, dtype=torch.long) + x_cont = normalize_cont(x_cont, cont_cols, mean, std, transforms=transforms) + if return_file_id: + x_file = torch.tensor(batch_file, dtype=torch.long) + yield x_cont, x_disc, x_file + else: + yield x_cont, x_disc + batch_cont = [] + batch_disc = [] + batch_file = [] + batches_yielded += 1 + if max_batches is not None and batches_yielded >= max_batches: + return + # Drop last partial batch for simplicity diff --git a/example/evaluate_generated.py b/example/evaluate_generated.py index b5431e8..7ed87e6 100644 --- a/example/evaluate_generated.py +++ b/example/evaluate_generated.py @@ -5,8 +5,9 @@ import argparse import csv import gzip import json +import math from pathlib import Path -from typing import Dict, Tuple +from typing import Dict, Tuple, List, Optional def load_json(path: str) -> Dict: @@ -28,6 +29,8 @@ def parse_args(): parser.add_argument("--stats", default=str(base_dir / "results" / "cont_stats.json")) parser.add_argument("--vocab", default=str(base_dir / "results" / "disc_vocab.json")) parser.add_argument("--out", default=str(base_dir / "results" / "eval.json")) + parser.add_argument("--reference", default="", help="Optional reference CSV (train) for richer metrics") + parser.add_argument("--max-rows", type=int, default=20000, help="Max rows to load for reference metrics") return parser.parse_args() @@ -55,6 +58,62 @@ def finalize_stats(stats): return out +def js_divergence(p, q, eps: float = 1e-12) -> float: + p = [max(x, eps) for x in p] + q = [max(x, eps) for x in q] + m = [(pi + qi) / 2.0 for pi, qi in zip(p, q)] + def kl(a, b): + return sum(ai * math.log(ai / bi, 2) for ai, bi in zip(a, b)) + return 0.5 * kl(p, m) + 0.5 * kl(q, m) + + +def ks_statistic(x: List[float], y: List[float]) -> float: + if not x or not y: + return 0.0 + x_sorted = sorted(x) + y_sorted = sorted(y) + n = len(x_sorted) + m = len(y_sorted) + i = j = 0 + cdf_x = cdf_y = 0.0 + d = 0.0 + while i < n and j < m: + if x_sorted[i] <= y_sorted[j]: + i += 1 + cdf_x = i / n + else: + j += 1 + cdf_y = j / m + d = max(d, abs(cdf_x - cdf_y)) + return d + + +def lag1_corr(values: List[float]) -> float: + if len(values) < 3: + return 0.0 + x = values[:-1] + y = values[1:] + mean_x = sum(x) / len(x) + mean_y = sum(y) / len(y) + num = sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y)) + den_x = sum((xi - mean_x) ** 2 for xi in x) + den_y = sum((yi - mean_y) ** 2 for yi in y) + if den_x <= 0 or den_y <= 0: + return 0.0 + return num / math.sqrt(den_x * den_y) + + +def resolve_reference_path(path: str) -> Optional[str]: + if not path: + return None + if any(ch in path for ch in ["*", "?", "["]): + base = Path(path).parent + pat = Path(path).name + matches = sorted(base.glob(pat)) + return str(matches[0]) if matches else None + return str(path) + + def main(): args = parse_args() base_dir = Path(__file__).resolve().parent @@ -63,13 +122,18 @@ def main(): args.stats = str((base_dir / args.stats).resolve()) if not Path(args.stats).is_absolute() else args.stats args.vocab = str((base_dir / 
args.vocab).resolve()) if not Path(args.vocab).is_absolute() else args.vocab
     args.out = str((base_dir / args.out).resolve()) if not Path(args.out).is_absolute() else args.out
+    if args.reference and not Path(args.reference).is_absolute():
+        args.reference = str((base_dir / args.reference).resolve())
+    ref_path = resolve_reference_path(args.reference)
 
     split = load_json(args.split)
     time_col = split.get("time_column", "time")
     cont_cols = [c for c in split["continuous"] if c != time_col]
     disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
-    stats_ref = load_json(args.stats)["mean"]
-    std_ref = load_json(args.stats)["std"]
+    stats_json = load_json(args.stats)
+    stats_ref = stats_json.get("raw_mean", stats_json.get("mean"))
+    std_ref = stats_json.get("raw_std", stats_json.get("std"))
+    transforms = stats_json.get("transform", {})
 
     vocab = load_json(args.vocab)["vocab"]
     vocab_sets = {c: set(vocab[c].keys()) for c in disc_cols}
@@ -112,6 +176,81 @@
         "discrete_invalid_counts": disc_invalid,
     }
 
+    # Optional richer metrics using reference data
+    if ref_path:
+        ref_cont = {c: [] for c in cont_cols}
+        ref_disc = {c: {} for c in disc_cols}
+        gen_cont = {c: [] for c in cont_cols}
+        gen_disc = {c: {} for c in disc_cols}
+
+        with open_csv(args.generated) as f:
+            reader = csv.DictReader(f)
+            for row in reader:
+                if time_col in row:
+                    row.pop(time_col, None)
+                for c in cont_cols:
+                    try:
+                        gen_cont[c].append(float(row[c]))
+                    except Exception:
+                        gen_cont[c].append(0.0)
+                for c in disc_cols:
+                    tok = row[c]
+                    gen_disc[c][tok] = gen_disc[c].get(tok, 0) + 1
+
+        with open_csv(ref_path) as f:
+            reader = csv.DictReader(f)
+            for i, row in enumerate(reader):
+                if time_col in row:
+                    row.pop(time_col, None)
+                for c in cont_cols:
+                    try:
+                        ref_cont[c].append(float(row[c]))
+                    except Exception:
+                        ref_cont[c].append(0.0)
+                for c in disc_cols:
+                    tok = row[c]
+                    ref_disc[c][tok] = ref_disc[c].get(tok, 0) + 1
+                if args.max_rows and i + 1 >= args.max_rows:
+                    break
+
+        # Continuous metrics: KS + quantiles + lag1 correlation
+        cont_ks = {}
+        cont_quant = {}
+        cont_lag1 = {}
+        for c in cont_cols:
+            cont_ks[c] = ks_statistic(gen_cont[c], ref_cont[c])
+            ref_sorted = sorted(ref_cont[c])
+            gen_sorted = sorted(gen_cont[c])
+            qs = [0.05, 0.25, 0.5, 0.75, 0.95]
+            def qval(arr, q):
+                if not arr:
+                    return 0.0
+                idx = int(q * (len(arr) - 1))
+                return arr[idx]
+            cont_quant[c] = {
+                "q05_diff": abs(qval(gen_sorted, 0.05) - qval(ref_sorted, 0.05)),
+                "q25_diff": abs(qval(gen_sorted, 0.25) - qval(ref_sorted, 0.25)),
+                "q50_diff": abs(qval(gen_sorted, 0.5) - qval(ref_sorted, 0.5)),
+                "q75_diff": abs(qval(gen_sorted, 0.75) - qval(ref_sorted, 0.75)),
+                "q95_diff": abs(qval(gen_sorted, 0.95) - qval(ref_sorted, 0.95)),
+            }
+            cont_lag1[c] = abs(lag1_corr(gen_cont[c]) - lag1_corr(ref_cont[c]))
+
+        # Discrete metrics: JSD over vocab
+        disc_jsd = {}
+        for c in disc_cols:
+            vocab_vals = list(vocab_sets[c])
+            gen_total = sum(gen_disc[c].values()) or 1
+            ref_total = sum(ref_disc[c].values()) or 1
+            p = [gen_disc[c].get(v, 0) / gen_total for v in vocab_vals]
+            q = [ref_disc[c].get(v, 0) / ref_total for v in vocab_vals]
+            disc_jsd[c] = js_divergence(p, q)
+
+        report["continuous_ks"] = cont_ks
+        report["continuous_quantile_diff"] = cont_quant
+        report["continuous_lag1_diff"] = cont_lag1
+        report["discrete_jsd"] = disc_jsd
+
     with 
open(args.out, "w", encoding="utf-8") as f: json.dump(report, f, indent=2) diff --git a/example/export_samples.py b/example/export_samples.py index 3bfed15..9c46660 100644 --- a/example/export_samples.py +++ b/example/export_samples.py @@ -111,6 +111,7 @@ def main(): vmax = stats.get("max", {}) int_like = stats.get("int_like", {}) max_decimals = stats.get("max_decimals", {}) + transforms = stats.get("transform", {}) vocab_json = json.load(open(args.vocab_path, "r", encoding="utf-8")) vocab = vocab_json["vocab"] @@ -141,6 +142,13 @@ def main(): model = HybridDiffusionModel( cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes, + time_dim=int(cfg.get("model_time_dim", 64)), + hidden_dim=int(cfg.get("model_hidden_dim", 256)), + num_layers=int(cfg.get("model_num_layers", 1)), + dropout=float(cfg.get("model_dropout", 0.0)), + ff_mult=int(cfg.get("model_ff_mult", 2)), + pos_dim=int(cfg.get("model_pos_dim", 64)), + use_pos_embed=bool(cfg.get("model_use_pos_embed", True)), cond_vocab_size=cond_vocab_size if use_condition else 0, cond_dim=int(cfg.get("cond_dim", 32)), use_tanh_eps=bool(cfg.get("use_tanh_eps", False)), @@ -220,6 +228,9 @@ def main(): mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype) std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype) x_cont = x_cont * std_vec + mean_vec + for i, c in enumerate(cont_cols): + if transforms.get(c) == "log1p": + x_cont[:, :, i] = torch.expm1(x_cont[:, :, i]) # clamp to observed min/max per feature if vmin and vmax: for i, c in enumerate(cont_cols): @@ -246,8 +257,8 @@ def main(): row["__cond_file_id"] = str(int(cond[b].item())) if cond is not None else "-1" if args.include_time and time_col in header: row[time_col] = str(row_index) - for i, c in enumerate(cont_cols): - val = float(x_cont[b, t, i]) + for i, c in enumerate(cont_cols): + val = float(x_cont[b, t, i]) if int_like.get(c, False): row[c] = str(int(round(val))) else: diff --git a/example/hybrid_diffusion.py b/example/hybrid_diffusion.py index 506a7ad..4f37356 100755 --- a/example/hybrid_diffusion.py +++ b/example/hybrid_diffusion.py @@ -35,11 +35,14 @@ def q_sample_discrete( t: torch.Tensor, mask_tokens: torch.Tensor, max_t: int, + mask_scale: float = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor]: """Randomly mask discrete tokens with a cosine schedule over t.""" bsz = x0.size(0) # cosine schedule: p(0)=0, p(max_t)=1 p = 0.5 * (1.0 - torch.cos(math.pi * t.float() / float(max_t))) + if mask_scale != 1.0: + p = torch.clamp(p * mask_scale, 0.0, 1.0) p = p.view(bsz, 1, 1) mask = torch.rand_like(x0.float()) < p x_masked = x0.clone() @@ -70,6 +73,11 @@ class HybridDiffusionModel(nn.Module): disc_vocab_sizes: List[int], time_dim: int = 64, hidden_dim: int = 256, + num_layers: int = 1, + dropout: float = 0.0, + ff_mult: int = 2, + pos_dim: int = 64, + use_pos_embed: bool = True, cond_vocab_size: int = 0, cond_dim: int = 32, use_tanh_eps: bool = False, @@ -82,6 +90,8 @@ class HybridDiffusionModel(nn.Module): self.time_embed = SinusoidalTimeEmbedding(time_dim) self.use_tanh_eps = use_tanh_eps self.eps_scale = eps_scale + self.pos_dim = pos_dim + self.use_pos_embed = use_pos_embed self.cond_vocab_size = cond_vocab_size self.cond_dim = cond_dim @@ -96,9 +106,22 @@ class HybridDiffusionModel(nn.Module): disc_embed_dim = sum(e.embedding_dim for e in self.disc_embeds) self.cont_proj = nn.Linear(cont_dim, cont_dim) - in_dim = cont_dim + disc_embed_dim + time_dim + (cond_dim if self.cond_embed is not None else 0) + pos_dim = pos_dim if use_pos_embed else 0 + in_dim = 
cont_dim + disc_embed_dim + time_dim + pos_dim + (cond_dim if self.cond_embed is not None else 0) self.in_proj = nn.Linear(in_dim, hidden_dim) - self.backbone = nn.GRU(hidden_dim, hidden_dim, batch_first=True) + self.backbone = nn.GRU( + hidden_dim, + hidden_dim, + num_layers=num_layers, + dropout=dropout if num_layers > 1 else 0.0, + batch_first=True, + ) + self.post_norm = nn.LayerNorm(hidden_dim) + self.post_ff = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim * ff_mult), + nn.GELU(), + nn.Linear(hidden_dim * ff_mult, hidden_dim), + ) self.cont_head = nn.Linear(hidden_dim, cont_dim) self.disc_heads = nn.ModuleList([ @@ -110,6 +133,9 @@ class HybridDiffusionModel(nn.Module): """x_cont: (B,T,Cc), x_disc: (B,T,Cd) with integer tokens.""" time_emb = self.time_embed(t) time_emb = time_emb.unsqueeze(1).expand(-1, x_cont.size(1), -1) + pos_emb = None + if self.use_pos_embed and self.pos_dim > 0: + pos_emb = self._positional_encoding(x_cont.size(1), self.pos_dim, x_cont.device) disc_embs = [] for i, emb in enumerate(self.disc_embeds): @@ -124,15 +150,28 @@ class HybridDiffusionModel(nn.Module): cont_feat = self.cont_proj(x_cont) parts = [cont_feat, disc_feat, time_emb] + if pos_emb is not None: + parts.append(pos_emb.unsqueeze(0).expand(x_cont.size(0), -1, -1)) if cond_feat is not None: parts.append(cond_feat) feat = torch.cat(parts, dim=-1) feat = self.in_proj(feat) out, _ = self.backbone(feat) + out = self.post_norm(out) + out = out + self.post_ff(out) eps_pred = self.cont_head(out) if self.use_tanh_eps: eps_pred = torch.tanh(eps_pred) * self.eps_scale logits = [head(out) for head in self.disc_heads] return eps_pred, logits + + @staticmethod + def _positional_encoding(seq_len: int, dim: int, device: torch.device) -> torch.Tensor: + pos = torch.arange(seq_len, device=device).float().unsqueeze(1) + div = torch.exp(torch.arange(0, dim, 2, device=device).float() * (-math.log(10000.0) / dim)) + pe = torch.zeros(seq_len, dim, device=device) + pe[:, 0::2] = torch.sin(pos * div) + pe[:, 1::2] = torch.cos(pos * div) + return pe diff --git a/example/prepare_data.py b/example/prepare_data.py index 26eac22..e0427ac 100755 --- a/example/prepare_data.py +++ b/example/prepare_data.py @@ -5,7 +5,7 @@ import json from pathlib import Path from typing import Optional -from data_utils import compute_cont_stats, build_disc_stats, load_split +from data_utils import compute_cont_stats, build_disc_stats, load_split, choose_cont_transforms from platform_utils import safe_path, ensure_dir BASE_DIR = Path(__file__).resolve().parent @@ -27,20 +27,25 @@ def main(max_rows: Optional[int] = None): raise SystemExit("no train files found under %s" % str(DATA_GLOB)) data_paths = [safe_path(p) for p in data_paths] - mean, std, vmin, vmax, int_like, max_decimals = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows) + transforms, _ = choose_cont_transforms(data_paths, cont_cols, max_rows=max_rows) + cont_stats = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows, transforms=transforms) vocab, top_token = build_disc_stats(data_paths, disc_cols, max_rows=max_rows) ensure_dir(OUT_STATS.parent) with open(safe_path(OUT_STATS), "w", encoding="utf-8") as f: json.dump( { - "mean": mean, - "std": std, - "min": vmin, - "max": vmax, - "int_like": int_like, - "max_decimals": max_decimals, - "max_rows": max_rows, + "mean": cont_stats["mean"], + "std": cont_stats["std"], + "raw_mean": cont_stats["raw_mean"], + "raw_std": cont_stats["raw_std"], + "min": cont_stats["min"], + "max": cont_stats["max"], + "int_like": 
cont_stats["int_like"], + "max_decimals": cont_stats["max_decimals"], + "transform": cont_stats["transform"], + "skew": cont_stats["skew"], + "max_rows": cont_stats["max_rows"], }, f, indent=2, diff --git a/example/results/cont_stats.json b/example/results/cont_stats.json index 2faae18..c12ab06 100644 --- a/example/results/cont_stats.json +++ b/example/results/cont_stats.json @@ -45,7 +45,7 @@ "P3_PIT01": 668.9722350000003, "P4_HT_FD": -0.00010012580000000082, "P4_HT_LD": 35.41945000099953, - "P4_HT_PO": 35.4085699912002, + "P4_HT_PO": 2.6391372939040414, "P4_LD": 365.3833745803986, "P4_ST_FD": -6.5205999999999635e-06, "P4_ST_GOV": 17801.81294499996, @@ -100,7 +100,7 @@ "P3_PIT01": 1168.1071264424027, "P4_HT_FD": 0.002032582380617592, "P4_HT_LD": 33.212361169253235, - "P4_HT_PO": 31.187825914515162, + "P4_HT_PO": 1.7636196192459512, "P4_LD": 59.736616589045646, "P4_ST_FD": 0.0016428787127432496, "P4_ST_GOV": 1740.5997458128215, @@ -109,6 +109,116 @@ "P4_ST_PT01": 22.459962818146252, "P4_ST_TT01": 24.745939350221477 }, + "raw_mean": { + "P1_B2004": 0.08649086820000026, + "P1_B2016": 1.376161456000001, + "P1_B3004": 396.1861596906018, + "P1_B3005": 1037.372384413793, + "P1_B4002": 32.564872940799994, + "P1_B4005": 65.98190757240047, + "P1_B400B": 1925.0391570245934, + "P1_B4022": 36.28908066800001, + "P1_FCV02Z": 21.744261118400036, + "P1_FCV03D": 57.36123274140044, + "P1_FCV03Z": 58.05084519640002, + "P1_FT01": 184.18615112319728, + "P1_FT01Z": 851.8781750705965, + "P1_FT02": 1255.8572173544069, + "P1_FT02Z": 1925.0210755194114, + "P1_FT03": 269.37285885780574, + "P1_FT03Z": 1037.366172230601, + "P1_LCV01D": 11.228849048599963, + "P1_LCV01Z": 10.991610181600016, + "P1_LIT01": 396.8845311109994, + "P1_PCV01D": 53.80101618419986, + "P1_PCV01Z": 54.646640287199595, + "P1_PCV02Z": 12.017773542800072, + "P1_PIT01": 1.3692859488000075, + "P1_PIT02": 0.44459071260000227, + "P1_TIT01": 35.64255813999988, + "P1_TIT02": 36.44807823060023, + "P2_24Vdc": 28.0280019013999, + "P2_CO_rpm": 54105.64434999997, + "P2_HILout": 712.0588667425922, + "P2_MSD": 763.19324, + "P2_SIT01": 778.7769850000013, + "P2_SIT02": 778.7778935471981, + "P2_VT01": 11.914949448200044, + "P2_VXT02": -3.5267871940000175, + "P2_VXT03": -1.5520904921999914, + "P2_VYT02": 3.796112737600002, + "P2_VYT03": 6.121691697000018, + "P3_FIT01": 1168.2528800000014, + "P3_LCP01D": 4675.465239999989, + "P3_LCV01D": 7445.208720000017, + "P3_LIT01": 13728.982314999852, + "P3_PIT01": 668.9722350000003, + "P4_HT_FD": -0.00010012580000000082, + "P4_HT_LD": 35.41945000099953, + "P4_HT_PO": 35.4085699912002, + "P4_LD": 365.3833745803986, + "P4_ST_FD": -6.5205999999999635e-06, + "P4_ST_GOV": 17801.81294499996, + "P4_ST_LD": 329.83259218199964, + "P4_ST_PO": 330.1079461497967, + "P4_ST_PT01": 10047.679605000127, + "P4_ST_TT01": 27606.860070000155 + }, + "raw_std": { + "P1_B2004": 0.024492489898690458, + "P1_B2016": 0.12949272564759745, + "P1_B3004": 10.16264800653289, + "P1_B3005": 70.85697659109, + "P1_B4002": 0.7578213113008355, + "P1_B4005": 41.80065314991797, + "P1_B400B": 1176.6445547448632, + "P1_B4022": 0.8221115066487089, + "P1_FCV02Z": 39.11843197764177, + "P1_FCV03D": 7.889507447726625, + "P1_FCV03Z": 8.046068905945717, + "P1_FT01": 30.80117031882856, + "P1_FT01Z": 91.2786865433318, + "P1_FT02": 879.7163277334494, + "P1_FT02Z": 1176.6699531305117, + "P1_FT03": 38.18015841964941, + "P1_FT03Z": 70.73100774436428, + "P1_LCV01D": 3.3355655415557597, + "P1_LCV01Z": 3.386332233773545, + "P1_LIT01": 10.578714760104122, + "P1_PCV01D": 
19.61567943613885, + "P1_PCV01Z": 19.778754467302086, + "P1_PCV02Z": 0.0048047978931599995, + "P1_PIT01": 0.0776614954053113, + "P1_PIT02": 0.44823231815652304, + "P1_TIT01": 0.5986678527528814, + "P1_TIT02": 1.1892341204521049, + "P2_24Vdc": 0.003208842504097816, + "P2_CO_rpm": 20.575477821507334, + "P2_HILout": 8.178853379908608, + "P2_MSD": 1.0, + "P2_SIT01": 3.894535775667256, + "P2_SIT02": 3.8824770788579395, + "P2_VT01": 0.06812990916670247, + "P2_VXT02": 0.43104157117568803, + "P2_VXT03": 0.26894251958139775, + "P2_VYT02": 0.4610907883207586, + "P2_VYT03": 0.30596429385075474, + "P3_FIT01": 1787.2987693141868, + "P3_LCP01D": 5145.4094261812725, + "P3_LCV01D": 6785.602781765096, + "P3_LIT01": 4060.915441872745, + "P3_PIT01": 1168.1071264424027, + "P4_HT_FD": 0.002032582380617592, + "P4_HT_LD": 33.21236116925323, + "P4_HT_PO": 31.18782591451516, + "P4_LD": 59.736616589045646, + "P4_ST_FD": 0.0016428787127432496, + "P4_ST_GOV": 1740.5997458128213, + "P4_ST_LD": 35.86633288900077, + "P4_ST_PO": 32.375012735256696, + "P4_ST_PT01": 22.45996281814625, + "P4_ST_TT01": 24.745939350221487 + }, "min": { "P1_B2004": 0.03051, "P1_B2016": 0.94729, @@ -329,5 +439,115 @@ "P4_ST_PT01": 2, "P4_ST_TT01": 2 }, + "transform": { + "P1_B2004": "none", + "P1_B2016": "none", + "P1_B3004": "none", + "P1_B3005": "none", + "P1_B4002": "none", + "P1_B4005": "none", + "P1_B400B": "none", + "P1_B4022": "none", + "P1_FCV02Z": "none", + "P1_FCV03D": "none", + "P1_FCV03Z": "none", + "P1_FT01": "none", + "P1_FT01Z": "none", + "P1_FT02": "none", + "P1_FT02Z": "none", + "P1_FT03": "none", + "P1_FT03Z": "none", + "P1_LCV01D": "none", + "P1_LCV01Z": "none", + "P1_LIT01": "none", + "P1_PCV01D": "none", + "P1_PCV01Z": "none", + "P1_PCV02Z": "none", + "P1_PIT01": "none", + "P1_PIT02": "none", + "P1_TIT01": "none", + "P1_TIT02": "none", + "P2_24Vdc": "none", + "P2_CO_rpm": "none", + "P2_HILout": "none", + "P2_MSD": "none", + "P2_SIT01": "none", + "P2_SIT02": "none", + "P2_VT01": "none", + "P2_VXT02": "none", + "P2_VXT03": "none", + "P2_VYT02": "none", + "P2_VYT03": "none", + "P3_FIT01": "none", + "P3_LCP01D": "none", + "P3_LCV01D": "none", + "P3_LIT01": "none", + "P3_PIT01": "none", + "P4_HT_FD": "none", + "P4_HT_LD": "none", + "P4_HT_PO": "log1p", + "P4_LD": "none", + "P4_ST_FD": "none", + "P4_ST_GOV": "none", + "P4_ST_LD": "none", + "P4_ST_PO": "none", + "P4_ST_PT01": "none", + "P4_ST_TT01": "none" + }, + "skew": { + "P1_B2004": -2.876938578031295e-05, + "P1_B2016": 2.014565216651284e-06, + "P1_B3004": 6.625985939357487e-06, + "P1_B3005": -9.917489652810193e-06, + "P1_B4002": 1.4641465884161855e-05, + "P1_B4005": -1.2370279269006856e-05, + "P1_B400B": -1.4116897198097317e-05, + "P1_B4022": 1.1162291352215598e-05, + "P1_FCV02Z": 2.532521501167817e-05, + "P1_FCV03D": 4.2517931711793e-06, + "P1_FCV03Z": 4.301856332440012e-06, + "P1_FT01": -1.3345735264961829e-05, + "P1_FT01Z": -4.2554413198354234e-05, + "P1_FT02": -1.0289230789249066e-05, + "P1_FT02Z": -1.4116856909216661e-05, + "P1_FT03": -4.341090521713463e-06, + "P1_FT03Z": -9.964308983887345e-06, + "P1_LCV01D": 2.541312481372867e-06, + "P1_LCV01Z": 2.5806433622267527e-06, + "P1_LIT01": 7.716120912717401e-06, + "P1_PCV01D": 2.113459306618771e-05, + "P1_PCV01Z": 2.0632832525407433e-05, + "P1_PCV02Z": 4.2639616636720384e-08, + "P1_PIT01": 2.079887220863843e-05, + "P1_PIT02": 5.003507344873546e-05, + "P1_TIT01": 9.553657000925262e-06, + "P1_TIT02": 2.1170357380515215e-05, + "P2_24Vdc": 2.735770838906968e-07, + "P2_CO_rpm": -8.124011608472296e-06, + "P2_HILout": 
-4.086282393330704e-06, + "P2_MSD": 0.0, + "P2_SIT01": -7.418240348817199e-06, + "P2_SIT02": -7.457826456660247e-06, + "P2_VT01": 1.247484205979928e-07, + "P2_VXT02": 6.53499778855353e-07, + "P2_VXT03": 5.32656056809399e-06, + "P2_VYT02": 9.483158480759724e-07, + "P2_VYT03": 2.128755351566922e-06, + "P3_FIT01": 2.2828575320599336e-05, + "P3_LCP01D": 1.3040993552131866e-05, + "P3_LCV01D": 3.781324885318626e-07, + "P3_LIT01": -7.824733758742217e-06, + "P3_PIT01": 3.210613447428708e-05, + "P4_HT_FD": 9.197840236384403e-05, + "P4_HT_LD": -2.4568845167931336e-08, + "P4_HT_PO": 3.997415489949367e-07, + "P4_LD": -6.253448074273654e-07, + "P4_ST_FD": 2.3472181460829935e-07, + "P4_ST_GOV": 2.494268873407866e-06, + "P4_ST_LD": 1.6692758818969547e-06, + "P4_ST_PO": 2.45129838870492e-06, + "P4_ST_PT01": 1.7637202837434092e-05, + "P4_ST_TT01": -1.9485876142550594e-05 + }, "max_rows": 50000 } \ No newline at end of file diff --git a/example/results/eval.json b/example/results/eval.json index c501c2e..e3041f7 100644 --- a/example/results/eval.json +++ b/example/results/eval.json @@ -233,7 +233,7 @@ }, "P1_B4002": { "mean_abs_err": 0.034608231074983564, - "std_abs_err": 0.03795674780254288 + "std_abs_err": 0.03795674780254299 }, "P1_B4005": { "mean_abs_err": 16.56784507240046, @@ -249,11 +249,11 @@ }, "P1_FCV02Z": { "mean_abs_err": 37.04518503394364, - "std_abs_err": 9.195664924295286 + "std_abs_err": 9.19566492429528 }, "P1_FCV03D": { "mean_abs_err": 0.2813618429629585, - "std_abs_err": 3.31742916874118 + "std_abs_err": 3.317429168741179 }, "P1_FCV03Z": { "mean_abs_err": 2.7769160948375244, @@ -273,7 +273,7 @@ }, "P1_FT02Z": { "mean_abs_err": 389.68675585144206, - "std_abs_err": 237.15423985136158 + "std_abs_err": 237.15423985136135 }, "P1_FT03": { "mean_abs_err": 12.373236123430615, @@ -305,7 +305,7 @@ }, "P1_PCV02Z": { "mean_abs_err": 0.006274240403053355, - "std_abs_err": 0.0139416978463145 + "std_abs_err": 0.013941697846314497 }, "P1_PIT01": { "mean_abs_err": 0.03821283356563221, @@ -317,7 +317,7 @@ }, "P1_TIT01": { "mean_abs_err": 0.13356975062511367, - "std_abs_err": 0.4775895846603686 + "std_abs_err": 0.4775895846603687 }, "P1_TIT02": { "mean_abs_err": 0.4872384686185143, @@ -325,15 +325,15 @@ }, "P2_24Vdc": { "mean_abs_err": 0.0035577079751085705, - "std_abs_err": 0.011396984418792682 + "std_abs_err": 0.011396984418792677 }, "P2_CO_rpm": { "mean_abs_err": 9.448949609344709, - "std_abs_err": 62.711918668665504 + "std_abs_err": 62.71191866866543 }, "P2_HILout": { "mean_abs_err": 5.394922836341834, - "std_abs_err": 23.44018542876042 + "std_abs_err": 23.440185428760422 }, "P2_MSD": { "mean_abs_err": 0.0, @@ -345,11 +345,11 @@ }, "P2_SIT02": { "mean_abs_err": 0.40448069108401796, - "std_abs_err": 14.67317239695784 + "std_abs_err": 14.673172396957842 }, "P2_VT01": { "mean_abs_err": 0.023083168987463765, - "std_abs_err": 0.053772803716143 + "std_abs_err": 0.05377280371614296 }, "P2_VXT02": { "mean_abs_err": 0.1497719303281424, @@ -361,11 +361,11 @@ }, "P2_VYT02": { "mean_abs_err": 0.06072680010000164, - "std_abs_err": 0.5584798619468422 + "std_abs_err": 0.5584798619468421 }, "P2_VYT03": { "mean_abs_err": 0.035759402078149094, - "std_abs_err": 0.6256833854459374 + "std_abs_err": 0.6256833854459373 }, "P3_FIT01": { "mean_abs_err": 1368.8645125781227, @@ -393,11 +393,11 @@ }, "P4_HT_LD": { "mean_abs_err": 3.0017542197495573, - "std_abs_err": 7.306147411731942 + "std_abs_err": 7.306147411731949 }, "P4_HT_PO": { "mean_abs_err": 4.280643741200194, - "std_abs_err": 8.947745679875599 + "std_abs_err": 
8.947745679875602 }, "P4_LD": { "mean_abs_err": 34.6203309378206, @@ -425,7 +425,7 @@ }, "P4_ST_TT01": { "mean_abs_err": 32.06905437515161, - "std_abs_err": 19.934628627185333 + "std_abs_err": 19.934628627185322 } }, "discrete_invalid_counts": { @@ -455,5 +455,516 @@ "P3_LL": 0, "P4_HT_PS": 0, "P4_ST_PS": 0 + }, + "continuous_ks": { + "P1_B2004": 0.8106, + "P1_B2016": 0.5790015625, + "P1_B3004": 0.53125, + "P1_B3005": 0.4782921875, + "P1_B4002": 0.8105, + "P1_B4005": 0.79705, + "P1_B400B": 0.595653125, + "P1_B4022": 0.59375, + "P1_FCV02Z": 0.6123046875, + "P1_FCV03D": 0.50390625, + "P1_FCV03Z": 0.61328125, + "P1_FT01": 0.5380859375, + "P1_FT01Z": 0.5390625, + "P1_FT02": 0.6317859375, + "P1_FT02Z": 0.533153125, + "P1_FT03": 0.53125, + "P1_FT03Z": 0.587890625, + "P1_LCV01D": 0.5966796875, + "P1_LCV01Z": 0.611328125, + "P1_LIT01": 0.6025390625, + "P1_PCV01D": 0.5791015625, + "P1_PCV01Z": 0.685546875, + "P1_PCV02Z": 0.568359375, + "P1_PIT01": 0.543871875, + "P1_PIT02": 0.5101953125, + "P1_TIT01": 0.501953125, + "P1_TIT02": 0.6396484375, + "P2_24Vdc": 0.6011625, + "P2_CO_rpm": 0.532503125, + "P2_HILout": 0.524140625, + "P2_MSD": 1.0, + "P2_SIT01": 0.6023390625, + "P2_SIT02": 0.505659375, + "P2_VT01": 0.5654296875, + "P2_VXT02": 0.5615234375, + "P2_VXT03": 0.5126953125, + "P2_VYT02": 0.52734375, + "P2_VYT03": 0.583984375, + "P3_FIT01": 0.52734375, + "P3_LCP01D": 0.568359375, + "P3_LCV01D": 0.529296875, + "P3_LIT01": 0.533203125, + "P3_PIT01": 0.5654296875, + "P4_HT_FD": 0.5009390625, + "P4_HT_LD": 0.609375, + "P4_HT_PO": 0.625, + "P4_LD": 0.6279296875, + "P4_ST_FD": 0.5097921875, + "P4_ST_GOV": 0.577675, + "P4_ST_LD": 0.5361328125, + "P4_ST_PO": 0.51953125, + "P4_ST_PT01": 0.5004765625, + "P4_ST_TT01": 0.595703125 + }, + "continuous_quantile_diff": { + "P1_B2004": { + "q05_diff": 0.0, + "q25_diff": 0.0707, + "q50_diff": 0.0, + "q75_diff": 0.0, + "q95_diff": 0.0 + }, + "P1_B2016": { + "q05_diff": 0.20477000000000012, + "q25_diff": 0.35475999999999996, + "q50_diff": 0.4335500000000001, + "q75_diff": 0.5400499999999999, + "q95_diff": 0.4267000000000001 + }, + "P1_B3004": { + "q05_diff": 14.80313000000001, + "q25_diff": 19.27872000000002, + "q50_diff": 19.27872000000002, + "q75_diff": 18.877499999999998, + "q95_diff": 18.877499999999998 + }, + "P1_B3005": { + "q05_diff": 107.27929999999992, + "q25_diff": 107.27929999999992, + "q50_diff": 113.12176000000011, + "q75_diff": 113.12176000000011, + "q95_diff": 0.0 + }, + "P1_B4002": { + "q05_diff": 0.0, + "q25_diff": 1.6555000000000035, + "q50_diff": 1.6555000000000035, + "q75_diff": 0.0, + "q95_diff": 0.0 + }, + "P1_B4005": { + "q05_diff": 0.0, + "q25_diff": 100.0, + "q50_diff": 100.0, + "q75_diff": 0.0, + "q95_diff": 0.0 + }, + "P1_B400B": { + "q05_diff": 8.937850000000001, + "q25_diff": 2803.04775, + "q50_diff": 22.87377000000015, + "q75_diff": 18.10082999999986, + "q95_diff": 12.064449999999852 + }, + "P1_B4022": { + "q05_diff": 0.741570000000003, + "q25_diff": 2.147590000000001, + "q50_diff": 2.5206700000000026, + "q75_diff": 0.883100000000006, + "q95_diff": 0.5329200000000043 + }, + "P1_FCV02Z": { + "q05_diff": 0.015249999999999986, + "q25_diff": 0.015249999999999986, + "q50_diff": 99.09821000000001, + "q75_diff": 99.09821000000001, + "q95_diff": 0.030520000000009873 + }, + "P1_FCV03D": { + "q05_diff": 4.971059999999994, + "q25_diff": 5.407429999999998, + "q50_diff": 5.7296499999999995, + "q75_diff": 16.195880000000002, + "q95_diff": 1.1158599999999979 + }, + "P1_FCV03Z": { + "q05_diff": 5.249020000000002, + "q25_diff": 5.607600000000005, + 
"q50_diff": 5.729670000000006, + "q75_diff": 16.738889999999998, + "q95_diff": 1.1749300000000034 + }, + "P1_FT01": { + "q05_diff": 120.43077000000001, + "q25_diff": 130.32031, + "q50_diff": 83.73258999999999, + "q75_diff": 78.01056, + "q95_diff": 34.52304000000001 + }, + "P1_FT01Z": { + "q05_diff": 387.17252, + "q25_diff": 407.15109, + "q50_diff": 187.60754000000009, + "q75_diff": 174.5613400000001, + "q95_diff": 75.40979000000004 + }, + "P1_FT02": { + "q05_diff": 1.7166099999999993, + "q25_diff": 1961.32641, + "q50_diff": 31.47144000000003, + "q75_diff": 24.98621000000003, + "q95_diff": 16.78478999999993 + }, + "P1_FT02Z": { + "q05_diff": 8.937850000000001, + "q25_diff": 2803.04775, + "q50_diff": 22.87377000000015, + "q75_diff": 18.10082999999986, + "q95_diff": 12.064449999999852 + }, + "P1_FT03": { + "q05_diff": 57.21861999999999, + "q25_diff": 58.36301999999998, + "q50_diff": 70.76033999999999, + "q75_diff": 69.23453999999998, + "q95_diff": 3.8145700000000033 + }, + "P1_FT03Z": { + "q05_diff": 130.45844, + "q25_diff": 133.06768999999997, + "q50_diff": 120.01207999999997, + "q75_diff": 116.53325999999993, + "q95_diff": 6.377929999999878 + }, + "P1_LCV01D": { + "q05_diff": 4.2898, + "q25_diff": 5.0423599999999995, + "q50_diff": 11.84572, + "q75_diff": 9.60605, + "q95_diff": 4.67375 + }, + "P1_LCV01Z": { + "q05_diff": 4.40978, + "q25_diff": 5.241390000000001, + "q50_diff": 5.95092, + "q75_diff": 9.460459999999998, + "q95_diff": 4.226689999999998 + }, + "P1_LIT01": { + "q05_diff": 34.6062, + "q25_diff": 38.28658999999999, + "q50_diff": 35.42405000000002, + "q75_diff": 33.32828000000001, + "q95_diff": 29.698980000000006 + }, + "P1_PCV01D": { + "q05_diff": 10.95326, + "q25_diff": 15.403260000000003, + "q50_diff": 56.51593, + "q75_diff": 52.88774, + "q95_diff": 46.32521 + }, + "P1_PCV01Z": { + "q05_diff": 10.925290000000004, + "q25_diff": 15.548700000000004, + "q50_diff": 18.63098, + "q75_diff": 51.9104, + "q95_diff": 45.50171 + }, + "P1_PCV02Z": { + "q05_diff": 0.007629999999998915, + "q25_diff": 0.007629999999998915, + "q50_diff": 0.021849999999998815, + "q75_diff": 0.0228900000000003, + "q95_diff": 0.0228900000000003 + }, + "P1_PIT01": { + "q05_diff": 0.27142999999999995, + "q25_diff": 0.3475299999999999, + "q50_diff": 0.3565400000000001, + "q75_diff": 0.32497, + "q95_diff": 0.27395000000000014 + }, + "P1_PIT02": { + "q05_diff": 0.03356999999999999, + "q25_diff": 0.10451999999999997, + "q50_diff": 2.0233700000000003, + "q75_diff": 2.06985, + "q95_diff": 2.06909 + }, + "P1_TIT01": { + "q05_diff": 0.07629999999999626, + "q25_diff": 0.4272399999999976, + "q50_diff": 0.991819999999997, + "q75_diff": 0.5187900000000027, + "q95_diff": 0.09155000000000513 + }, + "P1_TIT02": { + "q05_diff": 0.09155000000000513, + "q25_diff": 0.4577600000000004, + "q50_diff": 1.2207000000000008, + "q75_diff": 3.2806399999999982, + "q95_diff": 1.2207099999999969 + }, + "P2_24Vdc": { + "q05_diff": 0.00946000000000069, + "q25_diff": 0.012100000000000222, + "q50_diff": 0.014739999999999753, + "q75_diff": 0.013829999999998677, + "q95_diff": 0.010680000000000689 + }, + "P2_CO_rpm": { + "q05_diff": 68.91000000000349, + "q25_diff": 88.13999999999942, + "q50_diff": 67.0, + "q75_diff": 54.0, + "q95_diff": 37.0 + }, + "P2_HILout": { + "q05_diff": 21.350099999999998, + "q25_diff": 31.207269999999994, + "q50_diff": 35.31494000000009, + "q75_diff": 21.63695999999993, + "q95_diff": 14.538569999999936 + }, + "P2_MSD": { + "q05_diff": 0.0, + "q25_diff": 0.0, + "q50_diff": 0.0, + "q75_diff": 0.0, + "q95_diff": 0.0 + }, + "P2_SIT01": 
{ + "q05_diff": 12.580000000000041, + "q25_diff": 16.710000000000036, + "q50_diff": 17.850000000000023, + "q75_diff": 16.899999999999977, + "q95_diff": 13.25 + }, + "P2_SIT02": { + "q05_diff": 12.788270000000011, + "q25_diff": 16.65368000000001, + "q50_diff": 14.857610000000022, + "q75_diff": 16.707580000000007, + "q95_diff": 13.430229999999938 + }, + "P2_VT01": { + "q05_diff": 0.015200000000000102, + "q25_diff": 0.046990000000000975, + "q50_diff": 0.13512000000000057, + "q75_diff": 0.07174000000000014, + "q95_diff": 0.03888000000000069 + }, + "P2_VXT02": { + "q05_diff": 0.08729999999999993, + "q25_diff": 0.2607000000000004, + "q50_diff": 0.6693000000000002, + "q75_diff": 0.9367000000000001, + "q95_diff": 0.7728999999999999 + }, + "P2_VXT03": { + "q05_diff": 0.07350000000000012, + "q25_diff": 0.18210000000000015, + "q50_diff": 0.4036000000000002, + "q75_diff": 1.0967, + "q95_diff": 0.9957 + }, + "P2_VYT02": { + "q05_diff": 0.3511000000000002, + "q25_diff": 0.5322, + "q50_diff": 0.9708999999999999, + "q75_diff": 0.6384999999999996, + "q95_diff": 0.4569000000000001 + }, + "P2_VYT03": { + "q05_diff": 0.6910999999999996, + "q25_diff": 0.8129999999999997, + "q50_diff": 0.8028000000000004, + "q75_diff": 0.5364000000000004, + "q95_diff": 0.41800000000000015 + }, + "P3_FIT01": { + "q05_diff": 2.0, + "q25_diff": 4.0, + "q50_diff": 76.0, + "q75_diff": 2735.0, + "q95_diff": 382.0 + }, + "P3_LCP01D": { + "q05_diff": 8.0, + "q25_diff": 56.0, + "q50_diff": 1760.0, + "q75_diff": 4112.0, + "q95_diff": 216.0 + }, + "P3_LCV01D": { + "q05_diff": 16.0, + "q25_diff": 336.0, + "q50_diff": 9488.0, + "q75_diff": 3376.0, + "q95_diff": 1584.0 + }, + "P3_LIT01": { + "q05_diff": 1310.0, + "q25_diff": 4632.0, + "q50_diff": 6346.0, + "q75_diff": 3409.0, + "q95_diff": 473.0 + }, + "P3_PIT01": { + "q05_diff": 2.0, + "q25_diff": 3.0, + "q50_diff": 4.0, + "q75_diff": 2855.0, + "q95_diff": 259.0 + }, + "P4_HT_FD": { + "q05_diff": 0.00863, + "q25_diff": 0.00963, + "q50_diff": 0.007980000000000001, + "q75_diff": 0.00971, + "q95_diff": 0.00881 + }, + "P4_HT_LD": { + "q05_diff": 0.0, + "q25_diff": 0.0, + "q50_diff": 54.74537, + "q75_diff": 14.424189999999996, + "q95_diff": 6.669560000000004 + }, + "P4_HT_PO": { + "q05_diff": 0.0, + "q25_diff": 1.35638, + "q50_diff": 43.402800000000006, + "q75_diff": 15.426150000000007, + "q95_diff": 7.179570000000012 + }, + "P4_LD": { + "q05_diff": 38.17633000000001, + "q25_diff": 89.59051, + "q50_diff": 137.02618, + "q75_diff": 84.97899999999998, + "q95_diff": 39.27954 + }, + "P4_ST_FD": { + "q05_diff": 0.00547, + "q25_diff": 0.00697, + "q50_diff": 0.00664, + "q75_diff": 0.00685, + "q95_diff": 0.0054800000000000005 + }, + "P4_ST_GOV": { + "q05_diff": 2299.0, + "q25_diff": 4178.0, + "q50_diff": 7730.490000000002, + "q75_diff": 7454.919999999998, + "q95_diff": 5812.0 + }, + "P4_ST_LD": { + "q05_diff": 39.333740000000034, + "q25_diff": 76.64203000000003, + "q50_diff": 100.76677999999998, + "q75_diff": 137.83997, + "q95_diff": 105.36016999999998 + }, + "P4_ST_PO": { + "q05_diff": 43.02301, + "q25_diff": 78.10687000000001, + "q50_diff": 136.67437999999999, + "q75_diff": 139.8797, + "q95_diff": 107.81970000000001 + }, + "P4_ST_PT01": { + "q05_diff": 83.0, + "q25_diff": 89.54999999999927, + "q50_diff": 61.399999999999636, + "q75_diff": 104.46999999999935, + "q95_diff": 76.94000000000051 + }, + "P4_ST_TT01": { + "q05_diff": 14.0, + "q25_diff": 48.0, + "q50_diff": 88.0, + "q75_diff": 2.0, + "q95_diff": 2.0 + } + }, + "continuous_lag1_diff": { + "P1_B2004": 0.9718697145204541, + "P1_B2016": 
1.0096989619256902, + "P1_B3004": 0.987235023697155, + "P1_B3005": 1.0293153257905971, + "P1_B4002": 1.0101720557598692, + "P1_B4005": 1.0069799798786847, + "P1_B400B": 1.0100831396536611, + "P1_B4022": 1.005260911857099, + "P1_FCV02Z": 1.0208150011699488, + "P1_FCV03D": 0.9677639878142119, + "P1_FCV03Z": 1.0448432960637175, + "P1_FT01": 1.0174834364370597, + "P1_FT01Z": 0.9623633857501661, + "P1_FT02": 0.9492143748470167, + "P1_FT02Z": 1.0248788433686373, + "P1_FT03": 0.9967197043976962, + "P1_FT03Z": 1.0042718034976392, + "P1_LCV01D": 0.9767924561102216, + "P1_LCV01Z": 1.0238429714137112, + "P1_LIT01": 0.994524637441244, + "P1_PCV01D": 0.9731502416003561, + "P1_PCV01Z": 0.9987281311850285, + "P1_PCV02Z": 0.5767325423170566, + "P1_PIT01": 1.0052067307895398, + "P1_PIT02": 1.07011185942615, + "P1_TIT01": 1.0555120205346828, + "P1_TIT02": 0.9917823846962297, + "P2_24Vdc": 0.005526132896643308, + "P2_CO_rpm": 0.40287161666960586, + "P2_HILout": 0.25625640743099676, + "P2_MSD": 0.0, + "P2_SIT01": 0.7010364996672233, + "P2_SIT02": 0.7090423064483174, + "P2_VT01": 0.8506494578213937, + "P2_VXT02": 0.8094431207350834, + "P2_VXT03": 0.8417667674176789, + "P2_VYT02": 0.8415060810530584, + "P2_VYT03": 0.8550842087621501, + "P3_FIT01": 0.9816722571345234, + "P3_LCP01D": 0.9961743411760093, + "P3_LCV01D": 0.9663083457613558, + "P3_LIT01": 1.0191146983629609, + "P3_PIT01": 0.9827441734544354, + "P4_HT_FD": 0.26359661930055545, + "P4_HT_LD": 1.0050543113208599, + "P4_HT_PO": 1.0522741078670352, + "P4_LD": 0.9692264478458912, + "P4_ST_FD": 0.37863299269548845, + "P4_ST_GOV": 1.0392419101031738, + "P4_ST_LD": 0.9837987258388408, + "P4_ST_PO": 1.0089254379868162, + "P4_ST_PT01": 1.0237355221468971, + "P4_ST_TT01": 0.9905047944110966 + }, + "discrete_jsd": { + "P1_FCV01D": 0.10738201625126076, + "P1_FCV01Z": 0.24100377687195113, + "P1_FCV02D": 0.059498960127610634, + "P1_PCV02D": 0.0, + "P1_PP01AD": 0.0, + "P1_PP01AR": 0.0, + "P1_PP01BD": 0.0, + "P1_PP01BR": 0.0, + "P1_PP02D": 0.0, + "P1_PP02R": 0.0, + "P1_STSP": 0.0, + "P2_ASD": 0.0, + "P2_AutoGO": 0.0, + "P2_Emerg": 0.0, + "P2_ManualGO": 0.0, + "P2_OnOff": 0.0, + "P2_RTR": 0.0, + "P2_TripEx": 0.0, + "P2_VTR01": 0.0, + "P2_VTR02": 0.0, + "P2_VTR03": 0.0, + "P2_VTR04": 0.0, + "P3_LH": 0.0, + "P3_LL": 0.0, + "P4_HT_PS": 0.0, + "P4_ST_PS": 0.0 } } \ No newline at end of file diff --git a/example/run_pipeline.py b/example/run_pipeline.py index 0f1fd69..c93974d 100644 --- a/example/run_pipeline.py +++ b/example/run_pipeline.py @@ -48,6 +48,8 @@ def main(): seq_len = cfg.get("sample_seq_len", cfg.get("seq_len", 64)) batch_size = cfg.get("sample_batch_size", cfg.get("batch_size", 2)) clip_k = cfg.get("clip_k", 5.0) + data_glob = cfg.get("data_glob", "") + data_path = cfg.get("data_path", "") run([sys.executable, str(base_dir / "prepare_data.py")]) run([sys.executable, str(base_dir / "train.py"), "--config", args.config, "--device", args.device]) run( @@ -70,7 +72,11 @@ def main(): "--use-ema", ] ) - run([sys.executable, str(base_dir / "evaluate_generated.py")]) + ref = data_glob if data_glob else data_path + if ref: + run([sys.executable, str(base_dir / "evaluate_generated.py"), "--reference", str(ref)]) + else: + run([sys.executable, str(base_dir / "evaluate_generated.py")]) run([sys.executable, str(base_dir / "plot_loss.py")]) diff --git a/example/sample.py b/example/sample.py index 14a5863..1b5f85d 100755 --- a/example/sample.py +++ b/example/sample.py @@ -47,6 +47,13 @@ def main(): cond_dim = int(cfg.get("cond_dim", 32)) use_tanh_eps = 
bool(cfg.get("use_tanh_eps", False)) eps_scale = float(cfg.get("eps_scale", 1.0)) + model_time_dim = int(cfg.get("model_time_dim", 64)) + model_hidden_dim = int(cfg.get("model_hidden_dim", 256)) + model_num_layers = int(cfg.get("model_num_layers", 1)) + model_dropout = float(cfg.get("model_dropout", 0.0)) + model_ff_mult = int(cfg.get("model_ff_mult", 2)) + model_pos_dim = int(cfg.get("model_pos_dim", 64)) + model_use_pos = bool(cfg.get("model_use_pos_embed", True)) split = load_split(str(SPLIT_PATH)) time_col = split.get("time_column", "time") @@ -67,6 +74,13 @@ def main(): model = HybridDiffusionModel( cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes, + time_dim=model_time_dim, + hidden_dim=model_hidden_dim, + num_layers=model_num_layers, + dropout=model_dropout, + ff_mult=model_ff_mult, + pos_dim=model_pos_dim, + use_pos_embed=model_use_pos, cond_vocab_size=cond_vocab_size, cond_dim=cond_dim, use_tanh_eps=use_tanh_eps, diff --git a/example/train.py b/example/train.py index 13699c7..467eb76 100755 --- a/example/train.py +++ b/example/train.py @@ -49,8 +49,17 @@ DEFAULTS = { "use_condition": True, "condition_type": "file_id", "cond_dim": 32, - "use_tanh_eps": True, + "use_tanh_eps": False, "eps_scale": 1.0, + "model_time_dim": 128, + "model_hidden_dim": 512, + "model_num_layers": 2, + "model_dropout": 0.1, + "model_ff_mult": 2, + "model_pos_dim": 64, + "model_use_pos_embed": True, + "disc_mask_scale": 0.9, + "shuffle_buffer": 256, } @@ -144,6 +153,7 @@ def main(): stats = load_json(config["stats_path"]) mean = stats["mean"] std = stats["std"] + transforms = stats.get("transform", {}) vocab = load_json(config["vocab_path"])["vocab"] vocab_sizes = [len(vocab[c]) for c in disc_cols] @@ -164,6 +174,13 @@ def main(): model = HybridDiffusionModel( cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes, + time_dim=int(config.get("model_time_dim", 64)), + hidden_dim=int(config.get("model_hidden_dim", 256)), + num_layers=int(config.get("model_num_layers", 1)), + dropout=float(config.get("model_dropout", 0.0)), + ff_mult=int(config.get("model_ff_mult", 2)), + pos_dim=int(config.get("model_pos_dim", 64)), + use_pos_embed=bool(config.get("model_use_pos_embed", True)), cond_vocab_size=cond_vocab_size, cond_dim=int(config.get("cond_dim", 32)), use_tanh_eps=bool(config.get("use_tanh_eps", False)), @@ -198,6 +215,8 @@ def main(): seq_len=int(config["seq_len"]), max_batches=int(config["max_batches"]), return_file_id=use_condition, + transforms=transforms, + shuffle_buffer=int(config.get("shuffle_buffer", 0)), ) ): if use_condition: @@ -215,7 +234,13 @@ def main(): x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod) mask_tokens = torch.tensor(vocab_sizes, device=device) - x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, int(config["timesteps"])) + x_disc_t, mask = q_sample_discrete( + x_disc, + t, + mask_tokens, + int(config["timesteps"]), + mask_scale=float(config.get("disc_mask_scale", 1.0)), + ) eps_pred, logits = model(x_cont_t, x_disc_t, t, cond)