Continuous features fall short on temporal (time-series) correlation
@@ -67,6 +67,7 @@ python example/run_pipeline.py --device auto
- Optional conditioning by file id (`train*.csv.gz`) is enabled by default for multi-file training.
- Continuous head can be bounded with `tanh` via `use_tanh_eps` in config.
- Export now clamps continuous features to training min/max and preserves integer/decimal precision.
- Continuous features may be log1p-transformed automatically for heavy-tailed columns (see cont_stats.json).
- `<UNK>` tokens are replaced by the most frequent token for each discrete column at export.
- The script only samples the first 5000 rows to stay fast.
- `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it.
@@ -6,11 +6,11 @@
  "vocab_path": "./results/disc_vocab.json",
  "out_dir": "./results",
  "device": "auto",
  "timesteps": 400,
  "timesteps": 600,
  "batch_size": 128,
  "seq_len": 128,
  "epochs": 8,
  "max_batches": 3000,
  "epochs": 10,
  "max_batches": 4000,
  "lambda": 0.5,
  "lr": 0.0005,
  "seed": 1337,
@@ -23,8 +23,17 @@
  "use_condition": true,
  "condition_type": "file_id",
  "cond_dim": 32,
  "use_tanh_eps": true,
  "use_tanh_eps": false,
  "eps_scale": 1.0,
  "model_time_dim": 128,
  "model_hidden_dim": 512,
  "model_num_layers": 2,
  "model_dropout": 0.1,
  "model_ff_mult": 2,
  "model_pos_dim": 64,
  "model_use_pos_embed": true,
  "disc_mask_scale": 0.9,
  "shuffle_buffer": 256,
  "sample_batch_size": 8,
  "sample_seq_len": 128
}
@@ -4,6 +4,8 @@
import csv
import gzip
import json
import math
import random
from typing import Dict, Iterable, List, Optional, Tuple, Union

@@ -23,62 +25,173 @@ def iter_rows(path_or_paths: Union[str, List[str]]) -> Iterable[Dict[str, str]]:
            yield row


def compute_cont_stats(
def _stream_basic_stats(
    path: Union[str, List[str]],
    cont_cols: List[str],
    max_rows: Optional[int] = None,
) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, bool], Dict[str, int]]:
    """Streaming mean/std (Welford) + min/max + int/precision metadata."""
    count = 0
):
    """Streaming stats with mean/M2/M3 + min/max + int/precision metadata."""
    count = {c: 0 for c in cont_cols}
    mean = {c: 0.0 for c in cont_cols}
    m2 = {c: 0.0 for c in cont_cols}
    m3 = {c: 0.0 for c in cont_cols}
    vmin = {c: float("inf") for c in cont_cols}
    vmax = {c: float("-inf") for c in cont_cols}
    int_like = {c: True for c in cont_cols}
    max_decimals = {c: 0 for c in cont_cols}
    all_pos = {c: True for c in cont_cols}

    for i, row in enumerate(iter_rows(path)):
        count += 1
        for c in cont_cols:
            raw = row[c]
            if raw is None or raw == "":
                continue
            x = float(raw)
            delta = x - mean[c]
            mean[c] += delta / count
            delta2 = x - mean[c]
            m2[c] += delta * delta2
            if x <= 0:
                all_pos[c] = False
            if x < vmin[c]:
                vmin[c] = x
            if x > vmax[c]:
                vmax[c] = x
            if int_like[c] and abs(x - round(x)) > 1e-9:
                int_like[c] = False
            # track decimal places from raw string if possible
            if "e" in raw or "E" in raw:
                # scientific notation, skip precision inference
                continue
            if "." in raw:
            if "e" not in raw and "E" not in raw and "." in raw:
                dec = raw.split(".", 1)[1].rstrip("0")
                if len(dec) > max_decimals[c]:
                    max_decimals[c] = len(dec)

            n = count[c] + 1
            delta = x - mean[c]
            delta_n = delta / n
            term1 = delta * delta_n * (n - 1)
            m3[c] += term1 * delta_n * (n - 2) - 3 * delta_n * m2[c]
            m2[c] += term1
            mean[c] += delta_n
            count[c] = n

        if max_rows is not None and i + 1 >= max_rows:
            break

    # finalize std/skew
    std = {}
    skew = {}
    for c in cont_cols:
        if count > 1:
            var = m2[c] / (count - 1)
        n = count[c]
        if n > 1:
            var = m2[c] / (n - 1)
        else:
            var = 0.0
        std[c] = var ** 0.5 if var > 0 else 1.0
        # replace infs if column had no valid values
        if n > 2 and m2[c] > 0:
            skew[c] = (math.sqrt(n) * (m3[c] / n)) / (m2[c] ** 1.5)
        else:
            skew[c] = 0.0

    for c in cont_cols:
        if vmin[c] == float("inf"):
            vmin[c] = 0.0
        if vmax[c] == float("-inf"):
            vmax[c] = 0.0
    return mean, std, vmin, vmax, int_like, max_decimals

    return {
        "count": count,
        "mean": mean,
        "std": std,
        "m2": m2,
        "m3": m3,
        "min": vmin,
        "max": vmax,
        "int_like": int_like,
        "max_decimals": max_decimals,
        "skew": skew,
        "all_pos": all_pos,
    }

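The single pass above folds the mean, M2 (variance) and M3 (skewness) updates into one loop and only finalizes std/skew at the end. A minimal standalone sketch of that update rule, with made-up sample values and no project imports:

```python
# Standalone sketch of the one-pass mean/M2/M3 update used above.
import math

def stream_moments(xs):
    n, mean, m2, m3 = 0, 0.0, 0.0, 0.0
    for x in xs:
        n += 1
        delta = x - mean
        delta_n = delta / n
        term1 = delta * delta_n * (n - 1)
        m3 += term1 * delta_n * (n - 2) - 3 * delta_n * m2   # uses the *old* m2
        m2 += term1
        mean += delta_n
    var = m2 / (n - 1) if n > 1 else 0.0
    skew = (math.sqrt(n) * (m3 / n)) / (m2 ** 1.5) if n > 2 and m2 > 0 else 0.0
    return mean, var, skew

mean, var, skew = stream_moments([1.0, 2.0, 2.0, 3.0, 10.0])
print(mean, var, skew)   # 3.6, 13.3, and a clearly positive skew (~0.27)
assert abs(mean - 3.6) < 1e-9 and abs(var - 13.3) < 1e-9
```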
def choose_cont_transforms(
    path: Union[str, List[str]],
    cont_cols: List[str],
    max_rows: Optional[int] = None,
    skew_threshold: float = 1.5,
    range_ratio_threshold: float = 1e3,
):
    """Pick per-column transform (currently log1p or none) based on skew/range."""
    stats = _stream_basic_stats(path, cont_cols, max_rows=max_rows)
    transforms = {}
    for c in cont_cols:
        if not stats["all_pos"][c]:
            transforms[c] = "none"
            continue
        skew = abs(stats["skew"][c])
        vmin = stats["min"][c]
        vmax = stats["max"][c]
        ratio = (vmax / vmin) if vmin > 0 else 0.0
        if skew >= skew_threshold or ratio >= range_ratio_threshold:
            transforms[c] = "log1p"
        else:
            transforms[c] = "none"
    return transforms, stats

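For intuition, the rule above reduces to a small predicate on the streamed summaries. A standalone restatement with invented numbers (not repository columns):

```python
# The selection rule above, restated over plain summary numbers.
def pick_transform(all_pos, skew, vmin, vmax,
                   skew_threshold=1.5, range_ratio_threshold=1e3):
    if not all_pos:
        return "none"                      # log1p only for strictly positive columns
    ratio = (vmax / vmin) if vmin > 0 else 0.0
    heavy_tail = abs(skew) >= skew_threshold or ratio >= range_ratio_threshold
    return "log1p" if heavy_tail else "none"

print(pick_transform(True, skew=4.2, vmin=0.1, vmax=900.0))   # log1p (wide dynamic range)
print(pick_transform(True, skew=0.3, vmin=10.0, vmax=50.0))   # none
print(pick_transform(False, skew=9.9, vmin=-5.0, vmax=5.0))   # none (non-positive values)
```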
def compute_cont_stats(
    path: Union[str, List[str]],
    cont_cols: List[str],
    max_rows: Optional[int] = None,
    transforms: Optional[Dict[str, str]] = None,
):
    """Compute stats on (optionally transformed) values. Returns raw + transformed stats."""
    # First pass (raw) for metadata and raw mean/std
    raw = _stream_basic_stats(path, cont_cols, max_rows=max_rows)

    # Optional transform selection
    if transforms is None:
        transforms = {c: "none" for c in cont_cols}

    # Second pass for transformed mean/std
    count = {c: 0 for c in cont_cols}
    mean = {c: 0.0 for c in cont_cols}
    m2 = {c: 0.0 for c in cont_cols}
    for i, row in enumerate(iter_rows(path)):
        for c in cont_cols:
            raw_val = row[c]
            if raw_val is None or raw_val == "":
                continue
            x = float(raw_val)
            if transforms.get(c) == "log1p":
                if x < 0:
                    x = 0.0
                x = math.log1p(x)
            n = count[c] + 1
            delta = x - mean[c]
            mean[c] += delta / n
            delta2 = x - mean[c]
            m2[c] += delta * delta2
            count[c] = n
        if max_rows is not None and i + 1 >= max_rows:
            break

    std = {}
    for c in cont_cols:
        if count[c] > 1:
            var = m2[c] / (count[c] - 1)
        else:
            var = 0.0
        std[c] = var ** 0.5 if var > 0 else 1.0

    return {
        "mean": mean,
        "std": std,
        "raw_mean": raw["mean"],
        "raw_std": raw["std"],
        "min": raw["min"],
        "max": raw["max"],
        "int_like": raw["int_like"],
        "max_decimals": raw["max_decimals"],
        "transform": transforms,
        "skew": raw["skew"],
        "all_pos": raw["all_pos"],
        "max_rows": max_rows,
    }

def build_vocab(
@@ -130,8 +243,19 @@ def build_disc_stats(
    return vocab, top_token


def normalize_cont(x, cont_cols: List[str], mean: Dict[str, float], std: Dict[str, float]):
def normalize_cont(
    x,
    cont_cols: List[str],
    mean: Dict[str, float],
    std: Dict[str, float],
    transforms: Optional[Dict[str, str]] = None,
):
    import torch

    if transforms:
        for i, c in enumerate(cont_cols):
            if transforms.get(c) == "log1p":
                x[:, :, i] = torch.log1p(torch.clamp(x[:, :, i], min=0))
    mean_t = torch.tensor([mean[c] for c in cont_cols], dtype=x.dtype, device=x.device)
    std_t = torch.tensor([std[c] for c in cont_cols], dtype=x.dtype, device=x.device)
    return (x - mean_t) / std_t
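Normalization at training time (optional log1p, then z-score) and the inverse applied at export (de-standardize, then expm1, further down in this commit) round-trip exactly. A hedged sketch with invented values and stats:

```python
# Hedged round-trip sketch of the forward transform above and the export-side inverse.
import torch

vals = torch.tensor([[[0.5], [10.0], [250.0]]])            # (batch=1, seq=3, one column)
mean, std = 2.1, 1.4                                        # stats of the log1p'd column (made up)

z = (torch.log1p(torch.clamp(vals, min=0)) - mean) / std    # training-side normalization
back = torch.expm1(z * std + mean)                          # export-side inverse
print(torch.allclose(back, vals, atol=1e-4))                # True
```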
@@ -148,19 +272,34 @@ def windowed_batches(
    seq_len: int,
    max_batches: Optional[int] = None,
    return_file_id: bool = False,
    transforms: Optional[Dict[str, str]] = None,
    shuffle_buffer: int = 0,
):
    import torch
    batch_cont = []
    batch_disc = []
    batch_file = []
    buffer = []
    seq_cont = []
    seq_disc = []

    def flush_seq():
        nonlocal seq_cont, seq_disc, batch_cont, batch_disc
    def flush_seq(file_id: int):
        nonlocal seq_cont, seq_disc, batch_cont, batch_disc, batch_file
        if len(seq_cont) == seq_len:
            if shuffle_buffer and shuffle_buffer > 0:
                buffer.append((list(seq_cont), list(seq_disc), file_id))
                if len(buffer) >= shuffle_buffer:
                    idx = random.randrange(len(buffer))
                    seq_c, seq_d, seq_f = buffer.pop(idx)
                    batch_cont.append(seq_c)
                    batch_disc.append(seq_d)
                    if return_file_id:
                        batch_file.append(seq_f)
            else:
                batch_cont.append(seq_cont)
                batch_disc.append(seq_disc)
                if return_file_id:
                    batch_file.append(file_id)
        seq_cont = []
        seq_disc = []

@@ -173,13 +312,11 @@ def windowed_batches(
            seq_cont.append(cont_row)
            seq_disc.append(disc_row)
            if len(seq_cont) == seq_len:
                flush_seq()
                if return_file_id:
                    batch_file.append(file_id)
                flush_seq(file_id)
                if len(batch_cont) == batch_size:
                    x_cont = torch.tensor(batch_cont, dtype=torch.float32)
                    x_disc = torch.tensor(batch_disc, dtype=torch.long)
                    x_cont = normalize_cont(x_cont, cont_cols, mean, std)
                    x_cont = normalize_cont(x_cont, cont_cols, mean, std, transforms=transforms)
                    if return_file_id:
                        x_file = torch.tensor(batch_file, dtype=torch.long)
                        yield x_cont, x_disc, x_file
@@ -195,4 +332,29 @@ def windowed_batches(
            seq_cont = []
            seq_disc = []

    # Flush any remaining buffered sequences
    if shuffle_buffer and buffer:
        random.shuffle(buffer)
        for seq_c, seq_d, seq_f in buffer:
            batch_cont.append(seq_c)
            batch_disc.append(seq_d)
            if return_file_id:
                batch_file.append(seq_f)
            if len(batch_cont) == batch_size:
                import torch
                x_cont = torch.tensor(batch_cont, dtype=torch.float32)
                x_disc = torch.tensor(batch_disc, dtype=torch.long)
                x_cont = normalize_cont(x_cont, cont_cols, mean, std, transforms=transforms)
                if return_file_id:
                    x_file = torch.tensor(batch_file, dtype=torch.long)
                    yield x_cont, x_disc, x_file
                else:
                    yield x_cont, x_disc
                batch_cont = []
                batch_disc = []
                batch_file = []
                batches_yielded += 1
                if max_batches is not None and batches_yielded >= max_batches:
                    return

    # Drop last partial batch for simplicity
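The shuffle_buffer change decorrelates neighbouring windows without loading a whole file: sequences are pooled and popped at random once the buffer fills, and the leftover buffer is flushed shuffled at the end. The idea in isolation, with plain integers standing in for the (seq_cont, seq_disc, file_id) tuples:

```python
# Standalone sketch of the buffered-shuffle behaviour added above.
import random

def buffered_shuffle(items, shuffle_buffer=4, seed=0):
    rng = random.Random(seed)
    buffer = []
    for item in items:
        buffer.append(item)
        if len(buffer) >= shuffle_buffer:
            yield buffer.pop(rng.randrange(len(buffer)))
    rng.shuffle(buffer)          # final flush, mirroring the code above
    yield from buffer

print(list(buffered_shuffle(range(10))))   # locally shuffled order, all 10 items present
```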
@@ -5,8 +5,9 @@ import argparse
import csv
import gzip
import json
import math
from pathlib import Path
from typing import Dict, Tuple
from typing import Dict, Tuple, List, Optional


def load_json(path: str) -> Dict:
@@ -28,6 +29,8 @@ def parse_args():
    parser.add_argument("--stats", default=str(base_dir / "results" / "cont_stats.json"))
    parser.add_argument("--vocab", default=str(base_dir / "results" / "disc_vocab.json"))
    parser.add_argument("--out", default=str(base_dir / "results" / "eval.json"))
    parser.add_argument("--reference", default="", help="Optional reference CSV (train) for richer metrics")
    parser.add_argument("--max-rows", type=int, default=20000, help="Max rows to load for reference metrics")
    return parser.parse_args()

@@ -55,6 +58,62 @@ def finalize_stats(stats):
    return out


def js_divergence(p, q, eps: float = 1e-12) -> float:
    p = [max(x, eps) for x in p]
    q = [max(x, eps) for x in q]
    m = [(pi + qi) / 2.0 for pi, qi in zip(p, q)]
    def kl(a, b):
        return sum(ai * math.log(ai / bi, 2) for ai, bi in zip(a, b))
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def ks_statistic(x: List[float], y: List[float]) -> float:
    if not x or not y:
        return 0.0
    x_sorted = sorted(x)
    y_sorted = sorted(y)
    n = len(x_sorted)
    m = len(y_sorted)
    i = j = 0
    cdf_x = cdf_y = 0.0
    d = 0.0
    while i < n and j < m:
        if x_sorted[i] <= y_sorted[j]:
            i += 1
            cdf_x = i / n
        else:
            j += 1
            cdf_y = j / m
        d = max(d, abs(cdf_x - cdf_y))
    return d


def lag1_corr(values: List[float]) -> float:
    if len(values) < 3:
        return 0.0
    x = values[:-1]
    y = values[1:]
    mean_x = sum(x) / len(x)
    mean_y = sum(y) / len(y)
    num = sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y))
    den_x = sum((xi - mean_x) ** 2 for xi in x)
    den_y = sum((yi - mean_y) ** 2 for yi in y)
    if den_x <= 0 or den_y <= 0:
        return 0.0
    return num / math.sqrt(den_x * den_y)


def resolve_reference_path(path: str) -> Optional[str]:
    if not path:
        return None
    if any(ch in path for ch in ["*", "?", "["]):
        base = Path(path).parent
        pat = Path(path).name
        matches = sorted(base.glob(pat))
        return str(matches[0]) if matches else None
    return str(path)

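A quick sanity check of the new metric helpers on toy inputs; this assumes the functions can be imported from evaluate_generated.py, and the expected values follow from the definitions above:

```python
# Toy sanity checks for the three helpers above (import path is an assumption).
from evaluate_generated import js_divergence, ks_statistic, lag1_corr

print(js_divergence([0.7, 0.2, 0.1], [0.1, 0.2, 0.7]))  # between 0 and 1 (base-2 logs)
print(ks_statistic([1, 2, 3, 4], [10, 11, 12, 13]))     # 1.0: the two samples are disjoint
print(lag1_corr([1.0, 2.0, 3.0, 4.0, 5.0]))             # 1.0: a ramp is perfectly lag-1 correlated
```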
def main():
    args = parse_args()
    base_dir = Path(__file__).resolve().parent
@@ -63,13 +122,18 @@ def main():
    args.stats = str((base_dir / args.stats).resolve()) if not Path(args.stats).is_absolute() else args.stats
    args.vocab = str((base_dir / args.vocab).resolve()) if not Path(args.vocab).is_absolute() else args.vocab
    args.out = str((base_dir / args.out).resolve()) if not Path(args.out).is_absolute() else args.out
    if args.reference and not Path(args.reference).is_absolute():
        args.reference = str((base_dir / args.reference).resolve())
    ref_path = resolve_reference_path(args.reference)
    split = load_json(args.split)
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]

    stats_ref = load_json(args.stats)["mean"]
    std_ref = load_json(args.stats)["std"]
    stats_json = load_json(args.stats)
    stats_ref = stats_json.get("raw_mean", stats_json.get("mean"))
    std_ref = stats_json.get("raw_std", stats_json.get("std"))
    transforms = stats_json.get("transform", {})
    vocab = load_json(args.vocab)["vocab"]
    vocab_sets = {c: set(vocab[c].keys()) for c in disc_cols}
@@ -89,6 +153,8 @@ def main():
            except Exception:
                v = 0.0
            update_stats(cont_stats, c, v)
        if ref_path:
            pass
        for c in disc_cols:
            if row[c] not in vocab_sets[c]:
                disc_invalid[c] += 1
@@ -112,6 +178,81 @@ def main():
        "discrete_invalid_counts": disc_invalid,
    }

    # Optional richer metrics using reference data
    if ref_path:
        ref_cont = {c: [] for c in cont_cols}
        ref_disc = {c: {} for c in disc_cols}
        gen_cont = {c: [] for c in cont_cols}
        gen_disc = {c: {} for c in disc_cols}

        with open_csv(args.generated) as f:
            reader = csv.DictReader(f)
            for row in reader:
                if time_col in row:
                    row.pop(time_col, None)
                for c in cont_cols:
                    try:
                        gen_cont[c].append(float(row[c]))
                    except Exception:
                        gen_cont[c].append(0.0)
                for c in disc_cols:
                    tok = row[c]
                    gen_disc[c][tok] = gen_disc[c].get(tok, 0) + 1

        with open_csv(ref_path) as f:
            reader = csv.DictReader(f)
            for i, row in enumerate(reader):
                if time_col in row:
                    row.pop(time_col, None)
                for c in cont_cols:
                    try:
                        ref_cont[c].append(float(row[c]))
                    except Exception:
                        ref_cont[c].append(0.0)
                for c in disc_cols:
                    tok = row[c]
                    ref_disc[c][tok] = ref_disc[c].get(tok, 0) + 1
                if args.max_rows and i + 1 >= args.max_rows:
                    break

        # Continuous metrics: KS + quantiles + lag1 correlation
        cont_ks = {}
        cont_quant = {}
        cont_lag1 = {}
        for c in cont_cols:
            cont_ks[c] = ks_statistic(gen_cont[c], ref_cont[c])
            ref_sorted = sorted(ref_cont[c])
            gen_sorted = sorted(gen_cont[c])
            qs = [0.05, 0.25, 0.5, 0.75, 0.95]
            def qval(arr, q):
                if not arr:
                    return 0.0
                idx = int(q * (len(arr) - 1))
                return arr[idx]
            cont_quant[c] = {
                "q05_diff": abs(qval(gen_sorted, 0.05) - qval(ref_sorted, 0.05)),
                "q25_diff": abs(qval(gen_sorted, 0.25) - qval(ref_sorted, 0.25)),
                "q50_diff": abs(qval(gen_sorted, 0.5) - qval(ref_sorted, 0.5)),
                "q75_diff": abs(qval(gen_sorted, 0.75) - qval(ref_sorted, 0.75)),
                "q95_diff": abs(qval(gen_sorted, 0.95) - qval(ref_sorted, 0.95)),
            }
            cont_lag1[c] = abs(lag1_corr(gen_cont[c]) - lag1_corr(ref_cont[c]))

        # Discrete metrics: JSD over vocab
        disc_jsd = {}
        for c in disc_cols:
            vocab_vals = list(vocab_sets[c])
            gen_total = sum(gen_disc[c].values()) or 1
            ref_total = sum(ref_disc[c].values()) or 1
            p = [gen_disc[c].get(v, 0) / gen_total for v in vocab_vals]
            q = [ref_disc[c].get(v, 0) / ref_total for v in vocab_vals]
            disc_jsd[c] = js_divergence(p, q)

        report["continuous_ks"] = cont_ks
        report["continuous_quantile_diff"] = cont_quant
        report["continuous_lag1_diff"] = cont_lag1
        report["discrete_jsd"] = disc_jsd

    with open(args.out, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)
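The quantile differences above use a floor-index (nearest-rank) lookup rather than interpolation; a standalone restatement showing the behaviour on a tiny sorted sample:

```python
# Standalone restatement of the qval helper used above.
def qval_sorted(arr, q):
    if not arr:
        return 0.0
    return arr[int(q * (len(arr) - 1))]

data = sorted([5.0, 1.0, 3.0, 2.0, 4.0])
print(qval_sorted(data, 0.5))    # 3.0 (the median)
print(qval_sorted(data, 0.75))   # 4.0 (index int(0.75 * 4) = 3)
print(qval_sorted(data, 0.95))   # 4.0, not 5.0: int(0.95 * 4) floors to index 3
```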
@@ -111,6 +111,7 @@ def main():
    vmax = stats.get("max", {})
    int_like = stats.get("int_like", {})
    max_decimals = stats.get("max_decimals", {})
    transforms = stats.get("transform", {})

    vocab_json = json.load(open(args.vocab_path, "r", encoding="utf-8"))
    vocab = vocab_json["vocab"]
@@ -141,6 +142,13 @@ def main():
    model = HybridDiffusionModel(
        cont_dim=len(cont_cols),
        disc_vocab_sizes=vocab_sizes,
        time_dim=int(cfg.get("model_time_dim", 64)),
        hidden_dim=int(cfg.get("model_hidden_dim", 256)),
        num_layers=int(cfg.get("model_num_layers", 1)),
        dropout=float(cfg.get("model_dropout", 0.0)),
        ff_mult=int(cfg.get("model_ff_mult", 2)),
        pos_dim=int(cfg.get("model_pos_dim", 64)),
        use_pos_embed=bool(cfg.get("model_use_pos_embed", True)),
        cond_vocab_size=cond_vocab_size if use_condition else 0,
        cond_dim=int(cfg.get("cond_dim", 32)),
        use_tanh_eps=bool(cfg.get("use_tanh_eps", False)),
@@ -220,6 +228,9 @@ def main():
    mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
    std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
    x_cont = x_cont * std_vec + mean_vec
    for i, c in enumerate(cont_cols):
        if transforms.get(c) == "log1p":
            x_cont[:, :, i] = torch.expm1(x_cont[:, :, i])
    # clamp to observed min/max per feature
    if vmin and vmax:
        for i, c in enumerate(cont_cols):
@@ -35,11 +35,14 @@ def q_sample_discrete(
    t: torch.Tensor,
    mask_tokens: torch.Tensor,
    max_t: int,
    mask_scale: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Randomly mask discrete tokens with a cosine schedule over t."""
    bsz = x0.size(0)
    # cosine schedule: p(0)=0, p(max_t)=1
    p = 0.5 * (1.0 - torch.cos(math.pi * t.float() / float(max_t)))
    if mask_scale != 1.0:
        p = torch.clamp(p * mask_scale, 0.0, 1.0)
    p = p.view(bsz, 1, 1)
    mask = torch.rand_like(x0.float()) < p
    x_masked = x0.clone()
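The masking probability follows a cosine schedule in t and is optionally capped by disc_mask_scale. A standalone scalar sketch using the values from this commit's config (timesteps=600, disc_mask_scale=0.9):

```python
# Standalone sketch of the cosine masking schedule above.
import math

def mask_prob(t, max_t, mask_scale=1.0):
    p = 0.5 * (1.0 - math.cos(math.pi * t / max_t))
    return min(max(p * mask_scale, 0.0), 1.0)

for t in (0, 150, 300, 450, 600):
    print(t, round(mask_prob(t, 600), 3), round(mask_prob(t, 600, mask_scale=0.9), 3))
# halfway (t=300) gives p=0.5; at t=600 the scaled schedule tops out at 0.9
```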
@@ -70,6 +73,11 @@ class HybridDiffusionModel(nn.Module):
        disc_vocab_sizes: List[int],
        time_dim: int = 64,
        hidden_dim: int = 256,
        num_layers: int = 1,
        dropout: float = 0.0,
        ff_mult: int = 2,
        pos_dim: int = 64,
        use_pos_embed: bool = True,
        cond_vocab_size: int = 0,
        cond_dim: int = 32,
        use_tanh_eps: bool = False,
@@ -82,6 +90,8 @@ class HybridDiffusionModel(nn.Module):
        self.time_embed = SinusoidalTimeEmbedding(time_dim)
        self.use_tanh_eps = use_tanh_eps
        self.eps_scale = eps_scale
        self.pos_dim = pos_dim
        self.use_pos_embed = use_pos_embed

        self.cond_vocab_size = cond_vocab_size
        self.cond_dim = cond_dim
@@ -96,9 +106,22 @@ class HybridDiffusionModel(nn.Module):
        disc_embed_dim = sum(e.embedding_dim for e in self.disc_embeds)

        self.cont_proj = nn.Linear(cont_dim, cont_dim)
        in_dim = cont_dim + disc_embed_dim + time_dim + (cond_dim if self.cond_embed is not None else 0)
        pos_dim = pos_dim if use_pos_embed else 0
        in_dim = cont_dim + disc_embed_dim + time_dim + pos_dim + (cond_dim if self.cond_embed is not None else 0)
        self.in_proj = nn.Linear(in_dim, hidden_dim)
        self.backbone = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.backbone = nn.GRU(
            hidden_dim,
            hidden_dim,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0.0,
            batch_first=True,
        )
        self.post_norm = nn.LayerNorm(hidden_dim)
        self.post_ff = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * ff_mult),
            nn.GELU(),
            nn.Linear(hidden_dim * ff_mult, hidden_dim),
        )

        self.cont_head = nn.Linear(hidden_dim, cont_dim)
        self.disc_heads = nn.ModuleList([
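The backbone swap is self-contained: a multi-layer GRU with inter-layer dropout (hence the num_layers > 1 guard), followed by LayerNorm and a residual feed-forward block. A standalone shape check with toy sizes:

```python
# Standalone shape check of the new backbone configuration.
import torch
import torch.nn as nn

hidden_dim, ff_mult, num_layers, dropout = 32, 2, 2, 0.1
backbone = nn.GRU(hidden_dim, hidden_dim, num_layers=num_layers,
                  dropout=dropout if num_layers > 1 else 0.0, batch_first=True)
post_norm = nn.LayerNorm(hidden_dim)
post_ff = nn.Sequential(nn.Linear(hidden_dim, hidden_dim * ff_mult),
                        nn.GELU(),
                        nn.Linear(hidden_dim * ff_mult, hidden_dim))

feat = torch.randn(4, 128, hidden_dim)   # (batch, seq_len, hidden)
out, _ = backbone(feat)
out = post_norm(out)
out = out + post_ff(out)                 # residual, as in the forward pass below
print(out.shape)                         # torch.Size([4, 128, 32])
```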
@@ -110,6 +133,9 @@ class HybridDiffusionModel(nn.Module):
        """x_cont: (B,T,Cc), x_disc: (B,T,Cd) with integer tokens."""
        time_emb = self.time_embed(t)
        time_emb = time_emb.unsqueeze(1).expand(-1, x_cont.size(1), -1)
        pos_emb = None
        if self.use_pos_embed and self.pos_dim > 0:
            pos_emb = self._positional_encoding(x_cont.size(1), self.pos_dim, x_cont.device)

        disc_embs = []
        for i, emb in enumerate(self.disc_embeds):
@@ -124,15 +150,28 @@ class HybridDiffusionModel(nn.Module):

        cont_feat = self.cont_proj(x_cont)
        parts = [cont_feat, disc_feat, time_emb]
        if pos_emb is not None:
            parts.append(pos_emb.unsqueeze(0).expand(x_cont.size(0), -1, -1))
        if cond_feat is not None:
            parts.append(cond_feat)
        feat = torch.cat(parts, dim=-1)
        feat = self.in_proj(feat)

        out, _ = self.backbone(feat)
        out = self.post_norm(out)
        out = out + self.post_ff(out)

        eps_pred = self.cont_head(out)
        if self.use_tanh_eps:
            eps_pred = torch.tanh(eps_pred) * self.eps_scale
        logits = [head(out) for head in self.disc_heads]
        return eps_pred, logits

    @staticmethod
    def _positional_encoding(seq_len: int, dim: int, device: torch.device) -> torch.Tensor:
        pos = torch.arange(seq_len, device=device).float().unsqueeze(1)
        div = torch.exp(torch.arange(0, dim, 2, device=device).float() * (-math.log(10000.0) / dim))
        pe = torch.zeros(seq_len, dim, device=device)
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        return pe
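Putting the pieces together, a hedged smoke test of the updated forward pass. It assumes HybridDiffusionModel is importable from model.py and that any constructor arguments not passed here (e.g. eps_scale) have defaults; all sizes are toy values rather than the project config:

```python
# Hedged smoke test for the updated forward pass (import path and defaults are assumptions).
import torch
from model import HybridDiffusionModel

model = HybridDiffusionModel(
    cont_dim=3,
    disc_vocab_sizes=[5, 7],
    time_dim=16,
    hidden_dim=32,
    num_layers=2,
    dropout=0.1,
    ff_mult=2,
    pos_dim=8,
    use_pos_embed=True,
    cond_vocab_size=4,
    cond_dim=6,
    use_tanh_eps=True,
)
x_cont = torch.randn(2, 10, 3)               # (batch, seq_len, continuous columns)
x_disc = torch.randint(0, 5, (2, 10, 2))     # integer tokens within each column's vocab
t = torch.randint(0, 600, (2,))
cond = torch.randint(0, 4, (2,))

eps_pred, logits = model(x_cont, x_disc, t, cond)
print(eps_pred.shape)                        # (2, 10, 3); bounded by tanh * eps_scale
print([tuple(l.shape[:2]) for l in logits])  # one (2, 10, *) logits tensor per discrete column
```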
@@ -5,7 +5,7 @@ import json
from pathlib import Path
from typing import Optional

from data_utils import compute_cont_stats, build_disc_stats, load_split
from data_utils import compute_cont_stats, build_disc_stats, load_split, choose_cont_transforms
from platform_utils import safe_path, ensure_dir

BASE_DIR = Path(__file__).resolve().parent
@@ -27,20 +27,25 @@ def main(max_rows: Optional[int] = None):
        raise SystemExit("no train files found under %s" % str(DATA_GLOB))
    data_paths = [safe_path(p) for p in data_paths]

    mean, std, vmin, vmax, int_like, max_decimals = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows)
    transforms, _ = choose_cont_transforms(data_paths, cont_cols, max_rows=max_rows)
    cont_stats = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows, transforms=transforms)
    vocab, top_token = build_disc_stats(data_paths, disc_cols, max_rows=max_rows)

    ensure_dir(OUT_STATS.parent)
    with open(safe_path(OUT_STATS), "w", encoding="utf-8") as f:
        json.dump(
            {
                "mean": mean,
                "std": std,
                "min": vmin,
                "max": vmax,
                "int_like": int_like,
                "max_decimals": max_decimals,
                "max_rows": max_rows,
                "mean": cont_stats["mean"],
                "std": cont_stats["std"],
                "raw_mean": cont_stats["raw_mean"],
                "raw_std": cont_stats["raw_std"],
                "min": cont_stats["min"],
                "max": cont_stats["max"],
                "int_like": cont_stats["int_like"],
                "max_decimals": cont_stats["max_decimals"],
                "transform": cont_stats["transform"],
                "skew": cont_stats["skew"],
                "max_rows": cont_stats["max_rows"],
            },
            f,
            indent=2,
@@ -45,7 +45,7 @@
|
||||
"P3_PIT01": 668.9722350000003,
|
||||
"P4_HT_FD": -0.00010012580000000082,
|
||||
"P4_HT_LD": 35.41945000099953,
|
||||
"P4_HT_PO": 35.4085699912002,
|
||||
"P4_HT_PO": 2.6391372939040414,
|
||||
"P4_LD": 365.3833745803986,
|
||||
"P4_ST_FD": -6.5205999999999635e-06,
|
||||
"P4_ST_GOV": 17801.81294499996,
|
||||
@@ -100,7 +100,7 @@
|
||||
"P3_PIT01": 1168.1071264424027,
|
||||
"P4_HT_FD": 0.002032582380617592,
|
||||
"P4_HT_LD": 33.212361169253235,
|
||||
"P4_HT_PO": 31.187825914515162,
|
||||
"P4_HT_PO": 1.7636196192459512,
|
||||
"P4_LD": 59.736616589045646,
|
||||
"P4_ST_FD": 0.0016428787127432496,
|
||||
"P4_ST_GOV": 1740.5997458128215,
|
||||
@@ -109,6 +109,116 @@
|
||||
"P4_ST_PT01": 22.459962818146252,
|
||||
"P4_ST_TT01": 24.745939350221477
|
||||
},
|
||||
"raw_mean": {
|
||||
"P1_B2004": 0.08649086820000026,
|
||||
"P1_B2016": 1.376161456000001,
|
||||
"P1_B3004": 396.1861596906018,
|
||||
"P1_B3005": 1037.372384413793,
|
||||
"P1_B4002": 32.564872940799994,
|
||||
"P1_B4005": 65.98190757240047,
|
||||
"P1_B400B": 1925.0391570245934,
|
||||
"P1_B4022": 36.28908066800001,
|
||||
"P1_FCV02Z": 21.744261118400036,
|
||||
"P1_FCV03D": 57.36123274140044,
|
||||
"P1_FCV03Z": 58.05084519640002,
|
||||
"P1_FT01": 184.18615112319728,
|
||||
"P1_FT01Z": 851.8781750705965,
|
||||
"P1_FT02": 1255.8572173544069,
|
||||
"P1_FT02Z": 1925.0210755194114,
|
||||
"P1_FT03": 269.37285885780574,
|
||||
"P1_FT03Z": 1037.366172230601,
|
||||
"P1_LCV01D": 11.228849048599963,
|
||||
"P1_LCV01Z": 10.991610181600016,
|
||||
"P1_LIT01": 396.8845311109994,
|
||||
"P1_PCV01D": 53.80101618419986,
|
||||
"P1_PCV01Z": 54.646640287199595,
|
||||
"P1_PCV02Z": 12.017773542800072,
|
||||
"P1_PIT01": 1.3692859488000075,
|
||||
"P1_PIT02": 0.44459071260000227,
|
||||
"P1_TIT01": 35.64255813999988,
|
||||
"P1_TIT02": 36.44807823060023,
|
||||
"P2_24Vdc": 28.0280019013999,
|
||||
"P2_CO_rpm": 54105.64434999997,
|
||||
"P2_HILout": 712.0588667425922,
|
||||
"P2_MSD": 763.19324,
|
||||
"P2_SIT01": 778.7769850000013,
|
||||
"P2_SIT02": 778.7778935471981,
|
||||
"P2_VT01": 11.914949448200044,
|
||||
"P2_VXT02": -3.5267871940000175,
|
||||
"P2_VXT03": -1.5520904921999914,
|
||||
"P2_VYT02": 3.796112737600002,
|
||||
"P2_VYT03": 6.121691697000018,
|
||||
"P3_FIT01": 1168.2528800000014,
|
||||
"P3_LCP01D": 4675.465239999989,
|
||||
"P3_LCV01D": 7445.208720000017,
|
||||
"P3_LIT01": 13728.982314999852,
|
||||
"P3_PIT01": 668.9722350000003,
|
||||
"P4_HT_FD": -0.00010012580000000082,
|
||||
"P4_HT_LD": 35.41945000099953,
|
||||
"P4_HT_PO": 35.4085699912002,
|
||||
"P4_LD": 365.3833745803986,
|
||||
"P4_ST_FD": -6.5205999999999635e-06,
|
||||
"P4_ST_GOV": 17801.81294499996,
|
||||
"P4_ST_LD": 329.83259218199964,
|
||||
"P4_ST_PO": 330.1079461497967,
|
||||
"P4_ST_PT01": 10047.679605000127,
|
||||
"P4_ST_TT01": 27606.860070000155
|
||||
},
|
||||
"raw_std": {
|
||||
"P1_B2004": 0.024492489898690458,
|
||||
"P1_B2016": 0.12949272564759745,
|
||||
"P1_B3004": 10.16264800653289,
|
||||
"P1_B3005": 70.85697659109,
|
||||
"P1_B4002": 0.7578213113008355,
|
||||
"P1_B4005": 41.80065314991797,
|
||||
"P1_B400B": 1176.6445547448632,
|
||||
"P1_B4022": 0.8221115066487089,
|
||||
"P1_FCV02Z": 39.11843197764177,
|
||||
"P1_FCV03D": 7.889507447726625,
|
||||
"P1_FCV03Z": 8.046068905945717,
|
||||
"P1_FT01": 30.80117031882856,
|
||||
"P1_FT01Z": 91.2786865433318,
|
||||
"P1_FT02": 879.7163277334494,
|
||||
"P1_FT02Z": 1176.6699531305117,
|
||||
"P1_FT03": 38.18015841964941,
|
||||
"P1_FT03Z": 70.73100774436428,
|
||||
"P1_LCV01D": 3.3355655415557597,
|
||||
"P1_LCV01Z": 3.386332233773545,
|
||||
"P1_LIT01": 10.578714760104122,
|
||||
"P1_PCV01D": 19.61567943613885,
|
||||
"P1_PCV01Z": 19.778754467302086,
|
||||
"P1_PCV02Z": 0.0048047978931599995,
|
||||
"P1_PIT01": 0.0776614954053113,
|
||||
"P1_PIT02": 0.44823231815652304,
|
||||
"P1_TIT01": 0.5986678527528814,
|
||||
"P1_TIT02": 1.1892341204521049,
|
||||
"P2_24Vdc": 0.003208842504097816,
|
||||
"P2_CO_rpm": 20.575477821507334,
|
||||
"P2_HILout": 8.178853379908608,
|
||||
"P2_MSD": 1.0,
|
||||
"P2_SIT01": 3.894535775667256,
|
||||
"P2_SIT02": 3.8824770788579395,
|
||||
"P2_VT01": 0.06812990916670247,
|
||||
"P2_VXT02": 0.43104157117568803,
|
||||
"P2_VXT03": 0.26894251958139775,
|
||||
"P2_VYT02": 0.4610907883207586,
|
||||
"P2_VYT03": 0.30596429385075474,
|
||||
"P3_FIT01": 1787.2987693141868,
|
||||
"P3_LCP01D": 5145.4094261812725,
|
||||
"P3_LCV01D": 6785.602781765096,
|
||||
"P3_LIT01": 4060.915441872745,
|
||||
"P3_PIT01": 1168.1071264424027,
|
||||
"P4_HT_FD": 0.002032582380617592,
|
||||
"P4_HT_LD": 33.21236116925323,
|
||||
"P4_HT_PO": 31.18782591451516,
|
||||
"P4_LD": 59.736616589045646,
|
||||
"P4_ST_FD": 0.0016428787127432496,
|
||||
"P4_ST_GOV": 1740.5997458128213,
|
||||
"P4_ST_LD": 35.86633288900077,
|
||||
"P4_ST_PO": 32.375012735256696,
|
||||
"P4_ST_PT01": 22.45996281814625,
|
||||
"P4_ST_TT01": 24.745939350221487
|
||||
},
|
||||
"min": {
|
||||
"P1_B2004": 0.03051,
|
||||
"P1_B2016": 0.94729,
|
||||
@@ -329,5 +439,115 @@
|
||||
"P4_ST_PT01": 2,
|
||||
"P4_ST_TT01": 2
|
||||
},
|
||||
"transform": {
|
||||
"P1_B2004": "none",
|
||||
"P1_B2016": "none",
|
||||
"P1_B3004": "none",
|
||||
"P1_B3005": "none",
|
||||
"P1_B4002": "none",
|
||||
"P1_B4005": "none",
|
||||
"P1_B400B": "none",
|
||||
"P1_B4022": "none",
|
||||
"P1_FCV02Z": "none",
|
||||
"P1_FCV03D": "none",
|
||||
"P1_FCV03Z": "none",
|
||||
"P1_FT01": "none",
|
||||
"P1_FT01Z": "none",
|
||||
"P1_FT02": "none",
|
||||
"P1_FT02Z": "none",
|
||||
"P1_FT03": "none",
|
||||
"P1_FT03Z": "none",
|
||||
"P1_LCV01D": "none",
|
||||
"P1_LCV01Z": "none",
|
||||
"P1_LIT01": "none",
|
||||
"P1_PCV01D": "none",
|
||||
"P1_PCV01Z": "none",
|
||||
"P1_PCV02Z": "none",
|
||||
"P1_PIT01": "none",
|
||||
"P1_PIT02": "none",
|
||||
"P1_TIT01": "none",
|
||||
"P1_TIT02": "none",
|
||||
"P2_24Vdc": "none",
|
||||
"P2_CO_rpm": "none",
|
||||
"P2_HILout": "none",
|
||||
"P2_MSD": "none",
|
||||
"P2_SIT01": "none",
|
||||
"P2_SIT02": "none",
|
||||
"P2_VT01": "none",
|
||||
"P2_VXT02": "none",
|
||||
"P2_VXT03": "none",
|
||||
"P2_VYT02": "none",
|
||||
"P2_VYT03": "none",
|
||||
"P3_FIT01": "none",
|
||||
"P3_LCP01D": "none",
|
||||
"P3_LCV01D": "none",
|
||||
"P3_LIT01": "none",
|
||||
"P3_PIT01": "none",
|
||||
"P4_HT_FD": "none",
|
||||
"P4_HT_LD": "none",
|
||||
"P4_HT_PO": "log1p",
|
||||
"P4_LD": "none",
|
||||
"P4_ST_FD": "none",
|
||||
"P4_ST_GOV": "none",
|
||||
"P4_ST_LD": "none",
|
||||
"P4_ST_PO": "none",
|
||||
"P4_ST_PT01": "none",
|
||||
"P4_ST_TT01": "none"
|
||||
},
|
||||
"skew": {
|
||||
"P1_B2004": -2.876938578031295e-05,
|
||||
"P1_B2016": 2.014565216651284e-06,
|
||||
"P1_B3004": 6.625985939357487e-06,
|
||||
"P1_B3005": -9.917489652810193e-06,
|
||||
"P1_B4002": 1.4641465884161855e-05,
|
||||
"P1_B4005": -1.2370279269006856e-05,
|
||||
"P1_B400B": -1.4116897198097317e-05,
|
||||
"P1_B4022": 1.1162291352215598e-05,
|
||||
"P1_FCV02Z": 2.532521501167817e-05,
|
||||
"P1_FCV03D": 4.2517931711793e-06,
|
||||
"P1_FCV03Z": 4.301856332440012e-06,
|
||||
"P1_FT01": -1.3345735264961829e-05,
|
||||
"P1_FT01Z": -4.2554413198354234e-05,
|
||||
"P1_FT02": -1.0289230789249066e-05,
|
||||
"P1_FT02Z": -1.4116856909216661e-05,
|
||||
"P1_FT03": -4.341090521713463e-06,
|
||||
"P1_FT03Z": -9.964308983887345e-06,
|
||||
"P1_LCV01D": 2.541312481372867e-06,
|
||||
"P1_LCV01Z": 2.5806433622267527e-06,
|
||||
"P1_LIT01": 7.716120912717401e-06,
|
||||
"P1_PCV01D": 2.113459306618771e-05,
|
||||
"P1_PCV01Z": 2.0632832525407433e-05,
|
||||
"P1_PCV02Z": 4.2639616636720384e-08,
|
||||
"P1_PIT01": 2.079887220863843e-05,
|
||||
"P1_PIT02": 5.003507344873546e-05,
|
||||
"P1_TIT01": 9.553657000925262e-06,
|
||||
"P1_TIT02": 2.1170357380515215e-05,
|
||||
"P2_24Vdc": 2.735770838906968e-07,
|
||||
"P2_CO_rpm": -8.124011608472296e-06,
|
||||
"P2_HILout": -4.086282393330704e-06,
|
||||
"P2_MSD": 0.0,
|
||||
"P2_SIT01": -7.418240348817199e-06,
|
||||
"P2_SIT02": -7.457826456660247e-06,
|
||||
"P2_VT01": 1.247484205979928e-07,
|
||||
"P2_VXT02": 6.53499778855353e-07,
|
||||
"P2_VXT03": 5.32656056809399e-06,
|
||||
"P2_VYT02": 9.483158480759724e-07,
|
||||
"P2_VYT03": 2.128755351566922e-06,
|
||||
"P3_FIT01": 2.2828575320599336e-05,
|
||||
"P3_LCP01D": 1.3040993552131866e-05,
|
||||
"P3_LCV01D": 3.781324885318626e-07,
|
||||
"P3_LIT01": -7.824733758742217e-06,
|
||||
"P3_PIT01": 3.210613447428708e-05,
|
||||
"P4_HT_FD": 9.197840236384403e-05,
|
||||
"P4_HT_LD": -2.4568845167931336e-08,
|
||||
"P4_HT_PO": 3.997415489949367e-07,
|
||||
"P4_LD": -6.253448074273654e-07,
|
||||
"P4_ST_FD": 2.3472181460829935e-07,
|
||||
"P4_ST_GOV": 2.494268873407866e-06,
|
||||
"P4_ST_LD": 1.6692758818969547e-06,
|
||||
"P4_ST_PO": 2.45129838870492e-06,
|
||||
"P4_ST_PT01": 1.7637202837434092e-05,
|
||||
"P4_ST_TT01": -1.9485876142550594e-05
|
||||
},
|
||||
"max_rows": 50000
|
||||
}
|
||||
@@ -233,7 +233,7 @@
|
||||
},
|
||||
"P1_B4002": {
|
||||
"mean_abs_err": 0.034608231074983564,
|
||||
"std_abs_err": 0.03795674780254288
|
||||
"std_abs_err": 0.03795674780254299
|
||||
},
|
||||
"P1_B4005": {
|
||||
"mean_abs_err": 16.56784507240046,
|
||||
@@ -249,11 +249,11 @@
|
||||
},
|
||||
"P1_FCV02Z": {
|
||||
"mean_abs_err": 37.04518503394364,
|
||||
"std_abs_err": 9.195664924295286
|
||||
"std_abs_err": 9.19566492429528
|
||||
},
|
||||
"P1_FCV03D": {
|
||||
"mean_abs_err": 0.2813618429629585,
|
||||
"std_abs_err": 3.31742916874118
|
||||
"std_abs_err": 3.317429168741179
|
||||
},
|
||||
"P1_FCV03Z": {
|
||||
"mean_abs_err": 2.7769160948375244,
|
||||
@@ -273,7 +273,7 @@
|
||||
},
|
||||
"P1_FT02Z": {
|
||||
"mean_abs_err": 389.68675585144206,
|
||||
"std_abs_err": 237.15423985136158
|
||||
"std_abs_err": 237.15423985136135
|
||||
},
|
||||
"P1_FT03": {
|
||||
"mean_abs_err": 12.373236123430615,
|
||||
@@ -305,7 +305,7 @@
|
||||
},
|
||||
"P1_PCV02Z": {
|
||||
"mean_abs_err": 0.006274240403053355,
|
||||
"std_abs_err": 0.0139416978463145
|
||||
"std_abs_err": 0.013941697846314497
|
||||
},
|
||||
"P1_PIT01": {
|
||||
"mean_abs_err": 0.03821283356563221,
|
||||
@@ -317,7 +317,7 @@
|
||||
},
|
||||
"P1_TIT01": {
|
||||
"mean_abs_err": 0.13356975062511367,
|
||||
"std_abs_err": 0.4775895846603686
|
||||
"std_abs_err": 0.4775895846603687
|
||||
},
|
||||
"P1_TIT02": {
|
||||
"mean_abs_err": 0.4872384686185143,
|
||||
@@ -325,15 +325,15 @@
|
||||
},
|
||||
"P2_24Vdc": {
|
||||
"mean_abs_err": 0.0035577079751085705,
|
||||
"std_abs_err": 0.011396984418792682
|
||||
"std_abs_err": 0.011396984418792677
|
||||
},
|
||||
"P2_CO_rpm": {
|
||||
"mean_abs_err": 9.448949609344709,
|
||||
"std_abs_err": 62.711918668665504
|
||||
"std_abs_err": 62.71191866866543
|
||||
},
|
||||
"P2_HILout": {
|
||||
"mean_abs_err": 5.394922836341834,
|
||||
"std_abs_err": 23.44018542876042
|
||||
"std_abs_err": 23.440185428760422
|
||||
},
|
||||
"P2_MSD": {
|
||||
"mean_abs_err": 0.0,
|
||||
@@ -345,11 +345,11 @@
|
||||
},
|
||||
"P2_SIT02": {
|
||||
"mean_abs_err": 0.40448069108401796,
|
||||
"std_abs_err": 14.67317239695784
|
||||
"std_abs_err": 14.673172396957842
|
||||
},
|
||||
"P2_VT01": {
|
||||
"mean_abs_err": 0.023083168987463765,
|
||||
"std_abs_err": 0.053772803716143
|
||||
"std_abs_err": 0.05377280371614296
|
||||
},
|
||||
"P2_VXT02": {
|
||||
"mean_abs_err": 0.1497719303281424,
|
||||
@@ -361,11 +361,11 @@
|
||||
},
|
||||
"P2_VYT02": {
|
||||
"mean_abs_err": 0.06072680010000164,
|
||||
"std_abs_err": 0.5584798619468422
|
||||
"std_abs_err": 0.5584798619468421
|
||||
},
|
||||
"P2_VYT03": {
|
||||
"mean_abs_err": 0.035759402078149094,
|
||||
"std_abs_err": 0.6256833854459374
|
||||
"std_abs_err": 0.6256833854459373
|
||||
},
|
||||
"P3_FIT01": {
|
||||
"mean_abs_err": 1368.8645125781227,
|
||||
@@ -393,11 +393,11 @@
|
||||
},
|
||||
"P4_HT_LD": {
|
||||
"mean_abs_err": 3.0017542197495573,
|
||||
"std_abs_err": 7.306147411731942
|
||||
"std_abs_err": 7.306147411731949
|
||||
},
|
||||
"P4_HT_PO": {
|
||||
"mean_abs_err": 4.280643741200194,
|
||||
"std_abs_err": 8.947745679875599
|
||||
"std_abs_err": 8.947745679875602
|
||||
},
|
||||
"P4_LD": {
|
||||
"mean_abs_err": 34.6203309378206,
|
||||
@@ -425,7 +425,7 @@
|
||||
},
|
||||
"P4_ST_TT01": {
|
||||
"mean_abs_err": 32.06905437515161,
|
||||
"std_abs_err": 19.934628627185333
|
||||
"std_abs_err": 19.934628627185322
|
||||
}
|
||||
},
|
||||
"discrete_invalid_counts": {
|
||||
@@ -455,5 +455,516 @@
|
||||
"P3_LL": 0,
|
||||
"P4_HT_PS": 0,
|
||||
"P4_ST_PS": 0
|
||||
},
|
||||
"continuous_ks": {
|
||||
"P1_B2004": 0.8106,
|
||||
"P1_B2016": 0.5790015625,
|
||||
"P1_B3004": 0.53125,
|
||||
"P1_B3005": 0.4782921875,
|
||||
"P1_B4002": 0.8105,
|
||||
"P1_B4005": 0.79705,
|
||||
"P1_B400B": 0.595653125,
|
||||
"P1_B4022": 0.59375,
|
||||
"P1_FCV02Z": 0.6123046875,
|
||||
"P1_FCV03D": 0.50390625,
|
||||
"P1_FCV03Z": 0.61328125,
|
||||
"P1_FT01": 0.5380859375,
|
||||
"P1_FT01Z": 0.5390625,
|
||||
"P1_FT02": 0.6317859375,
|
||||
"P1_FT02Z": 0.533153125,
|
||||
"P1_FT03": 0.53125,
|
||||
"P1_FT03Z": 0.587890625,
|
||||
"P1_LCV01D": 0.5966796875,
|
||||
"P1_LCV01Z": 0.611328125,
|
||||
"P1_LIT01": 0.6025390625,
|
||||
"P1_PCV01D": 0.5791015625,
|
||||
"P1_PCV01Z": 0.685546875,
|
||||
"P1_PCV02Z": 0.568359375,
|
||||
"P1_PIT01": 0.543871875,
|
||||
"P1_PIT02": 0.5101953125,
|
||||
"P1_TIT01": 0.501953125,
|
||||
"P1_TIT02": 0.6396484375,
|
||||
"P2_24Vdc": 0.6011625,
|
||||
"P2_CO_rpm": 0.532503125,
|
||||
"P2_HILout": 0.524140625,
|
||||
"P2_MSD": 1.0,
|
||||
"P2_SIT01": 0.6023390625,
|
||||
"P2_SIT02": 0.505659375,
|
||||
"P2_VT01": 0.5654296875,
|
||||
"P2_VXT02": 0.5615234375,
|
||||
"P2_VXT03": 0.5126953125,
|
||||
"P2_VYT02": 0.52734375,
|
||||
"P2_VYT03": 0.583984375,
|
||||
"P3_FIT01": 0.52734375,
|
||||
"P3_LCP01D": 0.568359375,
|
||||
"P3_LCV01D": 0.529296875,
|
||||
"P3_LIT01": 0.533203125,
|
||||
"P3_PIT01": 0.5654296875,
|
||||
"P4_HT_FD": 0.5009390625,
|
||||
"P4_HT_LD": 0.609375,
|
||||
"P4_HT_PO": 0.625,
|
||||
"P4_LD": 0.6279296875,
|
||||
"P4_ST_FD": 0.5097921875,
|
||||
"P4_ST_GOV": 0.577675,
|
||||
"P4_ST_LD": 0.5361328125,
|
||||
"P4_ST_PO": 0.51953125,
|
||||
"P4_ST_PT01": 0.5004765625,
|
||||
"P4_ST_TT01": 0.595703125
|
||||
},
|
||||
"continuous_quantile_diff": {
|
||||
"P1_B2004": {
|
||||
"q05_diff": 0.0,
|
||||
"q25_diff": 0.0707,
|
||||
"q50_diff": 0.0,
|
||||
"q75_diff": 0.0,
|
||||
"q95_diff": 0.0
|
||||
},
|
||||
"P1_B2016": {
|
||||
"q05_diff": 0.20477000000000012,
|
||||
"q25_diff": 0.35475999999999996,
|
||||
"q50_diff": 0.4335500000000001,
|
||||
"q75_diff": 0.5400499999999999,
|
||||
"q95_diff": 0.4267000000000001
|
||||
},
|
||||
"P1_B3004": {
|
||||
"q05_diff": 14.80313000000001,
|
||||
"q25_diff": 19.27872000000002,
|
||||
"q50_diff": 19.27872000000002,
|
||||
"q75_diff": 18.877499999999998,
|
||||
"q95_diff": 18.877499999999998
|
||||
},
|
||||
"P1_B3005": {
|
||||
"q05_diff": 107.27929999999992,
|
||||
"q25_diff": 107.27929999999992,
|
||||
"q50_diff": 113.12176000000011,
|
||||
"q75_diff": 113.12176000000011,
|
||||
"q95_diff": 0.0
|
||||
},
|
||||
"P1_B4002": {
|
||||
"q05_diff": 0.0,
|
||||
"q25_diff": 1.6555000000000035,
|
||||
"q50_diff": 1.6555000000000035,
|
||||
"q75_diff": 0.0,
|
||||
"q95_diff": 0.0
|
||||
},
|
||||
"P1_B4005": {
|
||||
"q05_diff": 0.0,
|
||||
"q25_diff": 100.0,
|
||||
"q50_diff": 100.0,
|
||||
"q75_diff": 0.0,
|
||||
"q95_diff": 0.0
|
||||
},
|
||||
"P1_B400B": {
|
||||
"q05_diff": 8.937850000000001,
|
||||
"q25_diff": 2803.04775,
|
||||
"q50_diff": 22.87377000000015,
|
||||
"q75_diff": 18.10082999999986,
|
||||
"q95_diff": 12.064449999999852
|
||||
},
|
||||
"P1_B4022": {
|
||||
"q05_diff": 0.741570000000003,
|
||||
"q25_diff": 2.147590000000001,
|
||||
"q50_diff": 2.5206700000000026,
|
||||
"q75_diff": 0.883100000000006,
|
||||
"q95_diff": 0.5329200000000043
|
||||
},
|
||||
"P1_FCV02Z": {
|
||||
"q05_diff": 0.015249999999999986,
|
||||
"q25_diff": 0.015249999999999986,
|
||||
"q50_diff": 99.09821000000001,
|
||||
"q75_diff": 99.09821000000001,
|
||||
"q95_diff": 0.030520000000009873
|
||||
},
|
||||
"P1_FCV03D": {
|
||||
"q05_diff": 4.971059999999994,
|
||||
"q25_diff": 5.407429999999998,
|
||||
"q50_diff": 5.7296499999999995,
|
||||
"q75_diff": 16.195880000000002,
|
||||
"q95_diff": 1.1158599999999979
|
||||
},
|
||||
"P1_FCV03Z": {
|
||||
"q05_diff": 5.249020000000002,
|
||||
"q25_diff": 5.607600000000005,
|
||||
"q50_diff": 5.729670000000006,
|
||||
"q75_diff": 16.738889999999998,
|
||||
"q95_diff": 1.1749300000000034
|
||||
},
|
||||
"P1_FT01": {
|
||||
"q05_diff": 120.43077000000001,
|
||||
"q25_diff": 130.32031,
|
||||
"q50_diff": 83.73258999999999,
|
||||
"q75_diff": 78.01056,
|
||||
"q95_diff": 34.52304000000001
|
||||
},
|
||||
"P1_FT01Z": {
|
||||
"q05_diff": 387.17252,
|
||||
"q25_diff": 407.15109,
|
||||
"q50_diff": 187.60754000000009,
|
||||
"q75_diff": 174.5613400000001,
|
||||
"q95_diff": 75.40979000000004
|
||||
},
|
||||
"P1_FT02": {
|
||||
"q05_diff": 1.7166099999999993,
|
||||
"q25_diff": 1961.32641,
|
||||
"q50_diff": 31.47144000000003,
|
||||
"q75_diff": 24.98621000000003,
|
||||
"q95_diff": 16.78478999999993
|
||||
},
|
||||
"P1_FT02Z": {
|
||||
"q05_diff": 8.937850000000001,
|
||||
"q25_diff": 2803.04775,
|
||||
"q50_diff": 22.87377000000015,
|
||||
"q75_diff": 18.10082999999986,
|
||||
"q95_diff": 12.064449999999852
|
||||
},
|
||||
"P1_FT03": {
|
||||
"q05_diff": 57.21861999999999,
|
||||
"q25_diff": 58.36301999999998,
|
||||
"q50_diff": 70.76033999999999,
|
||||
"q75_diff": 69.23453999999998,
|
||||
"q95_diff": 3.8145700000000033
|
||||
},
|
||||
"P1_FT03Z": {
|
||||
"q05_diff": 130.45844,
|
||||
"q25_diff": 133.06768999999997,
|
||||
"q50_diff": 120.01207999999997,
|
||||
"q75_diff": 116.53325999999993,
|
||||
"q95_diff": 6.377929999999878
|
||||
},
|
||||
"P1_LCV01D": {
|
||||
"q05_diff": 4.2898,
|
||||
"q25_diff": 5.0423599999999995,
|
||||
"q50_diff": 11.84572,
|
||||
"q75_diff": 9.60605,
|
||||
"q95_diff": 4.67375
|
||||
},
|
||||
"P1_LCV01Z": {
|
||||
"q05_diff": 4.40978,
|
||||
"q25_diff": 5.241390000000001,
|
||||
"q50_diff": 5.95092,
|
||||
"q75_diff": 9.460459999999998,
|
||||
"q95_diff": 4.226689999999998
|
||||
},
|
||||
"P1_LIT01": {
|
||||
"q05_diff": 34.6062,
|
||||
"q25_diff": 38.28658999999999,
|
||||
"q50_diff": 35.42405000000002,
|
||||
"q75_diff": 33.32828000000001,
|
||||
"q95_diff": 29.698980000000006
|
||||
},
|
||||
"P1_PCV01D": {
|
||||
"q05_diff": 10.95326,
|
||||
"q25_diff": 15.403260000000003,
|
||||
"q50_diff": 56.51593,
|
||||
"q75_diff": 52.88774,
|
||||
"q95_diff": 46.32521
|
||||
},
|
||||
"P1_PCV01Z": {
|
||||
"q05_diff": 10.925290000000004,
|
||||
"q25_diff": 15.548700000000004,
|
||||
"q50_diff": 18.63098,
|
||||
"q75_diff": 51.9104,
|
||||
"q95_diff": 45.50171
|
||||
},
|
||||
"P1_PCV02Z": {
|
||||
"q05_diff": 0.007629999999998915,
|
||||
"q25_diff": 0.007629999999998915,
|
||||
"q50_diff": 0.021849999999998815,
|
||||
"q75_diff": 0.0228900000000003,
|
||||
"q95_diff": 0.0228900000000003
|
||||
},
|
||||
"P1_PIT01": {
|
||||
"q05_diff": 0.27142999999999995,
|
||||
"q25_diff": 0.3475299999999999,
|
||||
"q50_diff": 0.3565400000000001,
|
||||
"q75_diff": 0.32497,
|
||||
"q95_diff": 0.27395000000000014
|
||||
},
|
||||
"P1_PIT02": {
|
||||
"q05_diff": 0.03356999999999999,
|
||||
"q25_diff": 0.10451999999999997,
|
||||
"q50_diff": 2.0233700000000003,
|
||||
"q75_diff": 2.06985,
|
||||
"q95_diff": 2.06909
|
||||
},
|
||||
"P1_TIT01": {
|
||||
"q05_diff": 0.07629999999999626,
|
||||
"q25_diff": 0.4272399999999976,
|
||||
"q50_diff": 0.991819999999997,
|
||||
"q75_diff": 0.5187900000000027,
|
||||
"q95_diff": 0.09155000000000513
|
||||
},
|
||||
"P1_TIT02": {
|
||||
"q05_diff": 0.09155000000000513,
|
||||
"q25_diff": 0.4577600000000004,
|
||||
"q50_diff": 1.2207000000000008,
|
||||
"q75_diff": 3.2806399999999982,
|
||||
"q95_diff": 1.2207099999999969
|
||||
},
|
||||
"P2_24Vdc": {
|
||||
"q05_diff": 0.00946000000000069,
|
||||
"q25_diff": 0.012100000000000222,
|
||||
"q50_diff": 0.014739999999999753,
|
||||
"q75_diff": 0.013829999999998677,
|
||||
"q95_diff": 0.010680000000000689
|
||||
},
|
||||
"P2_CO_rpm": {
|
||||
"q05_diff": 68.91000000000349,
|
||||
"q25_diff": 88.13999999999942,
|
||||
"q50_diff": 67.0,
|
||||
"q75_diff": 54.0,
|
||||
"q95_diff": 37.0
|
||||
},
|
||||
"P2_HILout": {
|
||||
"q05_diff": 21.350099999999998,
|
||||
"q25_diff": 31.207269999999994,
|
||||
"q50_diff": 35.31494000000009,
|
||||
"q75_diff": 21.63695999999993,
|
||||
"q95_diff": 14.538569999999936
|
||||
},
|
||||
"P2_MSD": {
|
||||
"q05_diff": 0.0,
|
||||
"q25_diff": 0.0,
|
||||
"q50_diff": 0.0,
|
||||
"q75_diff": 0.0,
|
||||
"q95_diff": 0.0
|
||||
},
|
||||
"P2_SIT01": {
|
||||
"q05_diff": 12.580000000000041,
|
||||
"q25_diff": 16.710000000000036,
|
||||
"q50_diff": 17.850000000000023,
|
||||
"q75_diff": 16.899999999999977,
|
||||
"q95_diff": 13.25
|
||||
},
|
||||
"P2_SIT02": {
|
||||
"q05_diff": 12.788270000000011,
|
||||
"q25_diff": 16.65368000000001,
|
||||
"q50_diff": 14.857610000000022,
|
||||
"q75_diff": 16.707580000000007,
|
||||
"q95_diff": 13.430229999999938
|
||||
},
|
||||
"P2_VT01": {
|
||||
"q05_diff": 0.015200000000000102,
|
||||
"q25_diff": 0.046990000000000975,
|
||||
"q50_diff": 0.13512000000000057,
|
||||
"q75_diff": 0.07174000000000014,
|
||||
"q95_diff": 0.03888000000000069
|
||||
},
|
||||
"P2_VXT02": {
|
||||
"q05_diff": 0.08729999999999993,
|
||||
"q25_diff": 0.2607000000000004,
|
||||
"q50_diff": 0.6693000000000002,
|
||||
"q75_diff": 0.9367000000000001,
|
||||
"q95_diff": 0.7728999999999999
|
||||
},
|
||||
"P2_VXT03": {
|
||||
"q05_diff": 0.07350000000000012,
|
||||
"q25_diff": 0.18210000000000015,
|
||||
"q50_diff": 0.4036000000000002,
|
||||
"q75_diff": 1.0967,
|
||||
"q95_diff": 0.9957
|
||||
},
|
||||
"P2_VYT02": {
|
||||
"q05_diff": 0.3511000000000002,
|
||||
"q25_diff": 0.5322,
|
||||
"q50_diff": 0.9708999999999999,
|
||||
"q75_diff": 0.6384999999999996,
|
||||
"q95_diff": 0.4569000000000001
|
||||
},
|
||||
"P2_VYT03": {
|
||||
"q05_diff": 0.6910999999999996,
|
||||
"q25_diff": 0.8129999999999997,
|
||||
"q50_diff": 0.8028000000000004,
|
||||
"q75_diff": 0.5364000000000004,
|
||||
"q95_diff": 0.41800000000000015
|
||||
},
|
||||
"P3_FIT01": {
|
||||
"q05_diff": 2.0,
|
||||
"q25_diff": 4.0,
|
||||
"q50_diff": 76.0,
|
||||
"q75_diff": 2735.0,
|
||||
"q95_diff": 382.0
|
||||
},
|
||||
"P3_LCP01D": {
|
||||
"q05_diff": 8.0,
|
||||
"q25_diff": 56.0,
|
||||
"q50_diff": 1760.0,
|
||||
"q75_diff": 4112.0,
|
||||
"q95_diff": 216.0
|
||||
},
|
||||
"P3_LCV01D": {
|
||||
"q05_diff": 16.0,
|
||||
"q25_diff": 336.0,
|
||||
"q50_diff": 9488.0,
|
||||
"q75_diff": 3376.0,
|
||||
"q95_diff": 1584.0
|
||||
},
|
||||
"P3_LIT01": {
|
||||
"q05_diff": 1310.0,
|
||||
"q25_diff": 4632.0,
|
||||
"q50_diff": 6346.0,
|
||||
"q75_diff": 3409.0,
|
||||
"q95_diff": 473.0
|
||||
},
|
||||
"P3_PIT01": {
|
||||
"q05_diff": 2.0,
|
||||
"q25_diff": 3.0,
|
||||
"q50_diff": 4.0,
|
||||
"q75_diff": 2855.0,
|
||||
"q95_diff": 259.0
|
||||
},
|
||||
"P4_HT_FD": {
|
||||
"q05_diff": 0.00863,
|
||||
"q25_diff": 0.00963,
|
||||
"q50_diff": 0.007980000000000001,
|
||||
"q75_diff": 0.00971,
|
||||
"q95_diff": 0.00881
|
||||
},
|
||||
"P4_HT_LD": {
|
||||
"q05_diff": 0.0,
|
||||
"q25_diff": 0.0,
|
||||
"q50_diff": 54.74537,
|
||||
"q75_diff": 14.424189999999996,
|
||||
"q95_diff": 6.669560000000004
|
||||
},
|
||||
"P4_HT_PO": {
|
||||
"q05_diff": 0.0,
|
||||
"q25_diff": 1.35638,
|
||||
"q50_diff": 43.402800000000006,
|
||||
"q75_diff": 15.426150000000007,
|
||||
"q95_diff": 7.179570000000012
|
||||
},
|
||||
"P4_LD": {
|
||||
"q05_diff": 38.17633000000001,
|
||||
"q25_diff": 89.59051,
|
||||
"q50_diff": 137.02618,
|
||||
"q75_diff": 84.97899999999998,
|
||||
"q95_diff": 39.27954
|
||||
},
|
||||
"P4_ST_FD": {
|
||||
"q05_diff": 0.00547,
|
||||
"q25_diff": 0.00697,
|
||||
"q50_diff": 0.00664,
|
||||
"q75_diff": 0.00685,
|
||||
"q95_diff": 0.0054800000000000005
|
||||
},
|
||||
"P4_ST_GOV": {
|
||||
"q05_diff": 2299.0,
|
||||
"q25_diff": 4178.0,
|
||||
"q50_diff": 7730.490000000002,
|
||||
"q75_diff": 7454.919999999998,
|
||||
"q95_diff": 5812.0
|
||||
},
|
||||
"P4_ST_LD": {
|
||||
"q05_diff": 39.333740000000034,
|
||||
"q25_diff": 76.64203000000003,
|
||||
"q50_diff": 100.76677999999998,
|
||||
"q75_diff": 137.83997,
|
||||
"q95_diff": 105.36016999999998
|
||||
},
|
||||
"P4_ST_PO": {
|
||||
"q05_diff": 43.02301,
|
||||
"q25_diff": 78.10687000000001,
|
||||
"q50_diff": 136.67437999999999,
|
||||
"q75_diff": 139.8797,
|
||||
"q95_diff": 107.81970000000001
|
||||
},
|
||||
"P4_ST_PT01": {
|
||||
"q05_diff": 83.0,
|
||||
"q25_diff": 89.54999999999927,
|
||||
"q50_diff": 61.399999999999636,
|
||||
"q75_diff": 104.46999999999935,
|
||||
"q95_diff": 76.94000000000051
|
||||
},
|
||||
"P4_ST_TT01": {
|
||||
"q05_diff": 14.0,
|
||||
"q25_diff": 48.0,
|
||||
"q50_diff": 88.0,
|
||||
"q75_diff": 2.0,
|
||||
"q95_diff": 2.0
|
||||
}
|
||||
},
|
||||
"continuous_lag1_diff": {
|
||||
"P1_B2004": 0.9718697145204541,
|
||||
"P1_B2016": 1.0096989619256902,
|
||||
"P1_B3004": 0.987235023697155,
|
||||
"P1_B3005": 1.0293153257905971,
|
||||
"P1_B4002": 1.0101720557598692,
|
||||
"P1_B4005": 1.0069799798786847,
|
||||
"P1_B400B": 1.0100831396536611,
|
||||
"P1_B4022": 1.005260911857099,
|
||||
"P1_FCV02Z": 1.0208150011699488,
|
||||
"P1_FCV03D": 0.9677639878142119,
|
||||
"P1_FCV03Z": 1.0448432960637175,
|
||||
"P1_FT01": 1.0174834364370597,
|
||||
"P1_FT01Z": 0.9623633857501661,
|
||||
"P1_FT02": 0.9492143748470167,
|
||||
"P1_FT02Z": 1.0248788433686373,
|
||||
"P1_FT03": 0.9967197043976962,
|
||||
"P1_FT03Z": 1.0042718034976392,
|
||||
"P1_LCV01D": 0.9767924561102216,
|
||||
"P1_LCV01Z": 1.0238429714137112,
|
||||
"P1_LIT01": 0.994524637441244,
|
||||
"P1_PCV01D": 0.9731502416003561,
|
||||
"P1_PCV01Z": 0.9987281311850285,
|
||||
"P1_PCV02Z": 0.5767325423170566,
|
||||
"P1_PIT01": 1.0052067307895398,
|
||||
"P1_PIT02": 1.07011185942615,
|
||||
"P1_TIT01": 1.0555120205346828,
|
||||
"P1_TIT02": 0.9917823846962297,
|
||||
"P2_24Vdc": 0.005526132896643308,
|
||||
"P2_CO_rpm": 0.40287161666960586,
|
||||
"P2_HILout": 0.25625640743099676,
|
||||
"P2_MSD": 0.0,
|
||||
"P2_SIT01": 0.7010364996672233,
|
||||
"P2_SIT02": 0.7090423064483174,
|
||||
"P2_VT01": 0.8506494578213937,
|
||||
"P2_VXT02": 0.8094431207350834,
|
||||
"P2_VXT03": 0.8417667674176789,
|
||||
"P2_VYT02": 0.8415060810530584,
|
||||
"P2_VYT03": 0.8550842087621501,
|
||||
"P3_FIT01": 0.9816722571345234,
|
||||
"P3_LCP01D": 0.9961743411760093,
|
||||
"P3_LCV01D": 0.9663083457613558,
|
||||
"P3_LIT01": 1.0191146983629609,
|
||||
"P3_PIT01": 0.9827441734544354,
|
||||
"P4_HT_FD": 0.26359661930055545,
|
||||
"P4_HT_LD": 1.0050543113208599,
|
||||
"P4_HT_PO": 1.0522741078670352,
|
||||
"P4_LD": 0.9692264478458912,
|
||||
"P4_ST_FD": 0.37863299269548845,
|
||||
"P4_ST_GOV": 1.0392419101031738,
|
||||
"P4_ST_LD": 0.9837987258388408,
|
||||
"P4_ST_PO": 1.0089254379868162,
|
||||
"P4_ST_PT01": 1.0237355221468971,
|
||||
"P4_ST_TT01": 0.9905047944110966
|
||||
},
|
||||
"discrete_jsd": {
|
||||
"P1_FCV01D": 0.10738201625126076,
|
||||
"P1_FCV01Z": 0.24100377687195113,
|
||||
"P1_FCV02D": 0.059498960127610634,
|
||||
"P1_PCV02D": 0.0,
|
||||
"P1_PP01AD": 0.0,
|
||||
"P1_PP01AR": 0.0,
|
||||
"P1_PP01BD": 0.0,
|
||||
"P1_PP01BR": 0.0,
|
||||
"P1_PP02D": 0.0,
|
||||
"P1_PP02R": 0.0,
|
||||
"P1_STSP": 0.0,
|
||||
"P2_ASD": 0.0,
|
||||
"P2_AutoGO": 0.0,
|
||||
"P2_Emerg": 0.0,
|
||||
"P2_ManualGO": 0.0,
|
||||
"P2_OnOff": 0.0,
|
||||
"P2_RTR": 0.0,
|
||||
"P2_TripEx": 0.0,
|
||||
"P2_VTR01": 0.0,
|
||||
"P2_VTR02": 0.0,
|
||||
"P2_VTR03": 0.0,
|
||||
"P2_VTR04": 0.0,
|
||||
"P3_LH": 0.0,
|
||||
"P3_LL": 0.0,
|
||||
"P4_HT_PS": 0.0,
|
||||
"P4_ST_PS": 0.0
|
||||
}
|
||||
}
|
||||
@@ -48,6 +48,8 @@ def main():
    seq_len = cfg.get("sample_seq_len", cfg.get("seq_len", 64))
    batch_size = cfg.get("sample_batch_size", cfg.get("batch_size", 2))
    clip_k = cfg.get("clip_k", 5.0)
    data_glob = cfg.get("data_glob", "")
    data_path = cfg.get("data_path", "")
    run([sys.executable, str(base_dir / "prepare_data.py")])
    run([sys.executable, str(base_dir / "train.py"), "--config", args.config, "--device", args.device])
    run(
@@ -70,6 +72,10 @@ def main():
            "--use-ema",
        ]
    )
    ref = data_glob if data_glob else data_path
    if ref:
        run([sys.executable, str(base_dir / "evaluate_generated.py"), "--reference", str(ref)])
    else:
        run([sys.executable, str(base_dir / "evaluate_generated.py")])
    run([sys.executable, str(base_dir / "plot_loss.py")])

@@ -47,6 +47,13 @@ def main():
    cond_dim = int(cfg.get("cond_dim", 32))
    use_tanh_eps = bool(cfg.get("use_tanh_eps", False))
    eps_scale = float(cfg.get("eps_scale", 1.0))
    model_time_dim = int(cfg.get("model_time_dim", 64))
    model_hidden_dim = int(cfg.get("model_hidden_dim", 256))
    model_num_layers = int(cfg.get("model_num_layers", 1))
    model_dropout = float(cfg.get("model_dropout", 0.0))
    model_ff_mult = int(cfg.get("model_ff_mult", 2))
    model_pos_dim = int(cfg.get("model_pos_dim", 64))
    model_use_pos = bool(cfg.get("model_use_pos_embed", True))

    split = load_split(str(SPLIT_PATH))
    time_col = split.get("time_column", "time")
@@ -67,6 +74,13 @@ def main():
    model = HybridDiffusionModel(
        cont_dim=len(cont_cols),
        disc_vocab_sizes=vocab_sizes,
        time_dim=model_time_dim,
        hidden_dim=model_hidden_dim,
        num_layers=model_num_layers,
        dropout=model_dropout,
        ff_mult=model_ff_mult,
        pos_dim=model_pos_dim,
        use_pos_embed=model_use_pos,
        cond_vocab_size=cond_vocab_size,
        cond_dim=cond_dim,
        use_tanh_eps=use_tanh_eps,
@@ -49,8 +49,17 @@ DEFAULTS = {
    "use_condition": True,
    "condition_type": "file_id",
    "cond_dim": 32,
    "use_tanh_eps": True,
    "use_tanh_eps": False,
    "eps_scale": 1.0,
    "model_time_dim": 128,
    "model_hidden_dim": 512,
    "model_num_layers": 2,
    "model_dropout": 0.1,
    "model_ff_mult": 2,
    "model_pos_dim": 64,
    "model_use_pos_embed": True,
    "disc_mask_scale": 0.9,
    "shuffle_buffer": 256,
}

@@ -144,6 +153,7 @@ def main():
    stats = load_json(config["stats_path"])
    mean = stats["mean"]
    std = stats["std"]
    transforms = stats.get("transform", {})

    vocab = load_json(config["vocab_path"])["vocab"]
    vocab_sizes = [len(vocab[c]) for c in disc_cols]
@@ -164,6 +174,13 @@ def main():
    model = HybridDiffusionModel(
        cont_dim=len(cont_cols),
        disc_vocab_sizes=vocab_sizes,
        time_dim=int(config.get("model_time_dim", 64)),
        hidden_dim=int(config.get("model_hidden_dim", 256)),
        num_layers=int(config.get("model_num_layers", 1)),
        dropout=float(config.get("model_dropout", 0.0)),
        ff_mult=int(config.get("model_ff_mult", 2)),
        pos_dim=int(config.get("model_pos_dim", 64)),
        use_pos_embed=bool(config.get("model_use_pos_embed", True)),
        cond_vocab_size=cond_vocab_size,
        cond_dim=int(config.get("cond_dim", 32)),
        use_tanh_eps=bool(config.get("use_tanh_eps", False)),
@@ -198,6 +215,8 @@ def main():
        seq_len=int(config["seq_len"]),
        max_batches=int(config["max_batches"]),
        return_file_id=use_condition,
        transforms=transforms,
        shuffle_buffer=int(config.get("shuffle_buffer", 0)),
    )
    ):
        if use_condition:
@@ -215,7 +234,13 @@ def main():
        x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)

        mask_tokens = torch.tensor(vocab_sizes, device=device)
        x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, int(config["timesteps"]))
        x_disc_t, mask = q_sample_discrete(
            x_disc,
            t,
            mask_tokens,
            int(config["timesteps"]),
            mask_scale=float(config.get("disc_mask_scale", 1.0)),
        )

        eps_pred, logits = model(x_cont_t, x_disc_t, t, cond)