update
@@ -67,6 +67,7 @@ python example/run_pipeline.py --device auto
- Optional conditioning by file id (`train*.csv.gz`) is enabled by default for multi-file training.
- Continuous head can be bounded with `tanh` via `use_tanh_eps` in config.
- Export now clamps continuous features to training min/max and preserves integer/decimal precision (see the sketch after this list).
- Continuous features may be log1p-transformed automatically for heavy-tailed columns (see cont_stats.json).
- `<UNK>` tokens are replaced by the most frequent token for each discrete column at export.
- The script only samples the first 5000 rows to stay fast.
- `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it.
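The export-time post-processing described in the clamping/precision and `<UNK>` bullets above works roughly as follows. This is a minimal sketch only, assuming the `min`/`max`/`int_like`/`max_decimals` fields written to cont_stats.json and the per-column `top_token` map produced by `build_disc_stats`; `export_postprocess` and its arguments are illustrative names, not the exact code in sample.py.

```python
# Hypothetical sketch of export-time post-processing (names are illustrative).
def export_postprocess(row, cont_cols, disc_cols, stats, top_token):
    out = {}
    for c in cont_cols:
        x = float(row[c])
        # clamp to the min/max observed during training
        x = min(max(x, stats["min"][c]), stats["max"][c])
        # preserve integer-ness / decimal precision seen in the training data
        if stats["int_like"][c]:
            out[c] = str(int(round(x)))
        else:
            out[c] = f"{x:.{stats['max_decimals'][c]}f}"
    for c in disc_cols:
        tok = row[c]
        # fall back to the most frequent training token for <UNK> samples
        out[c] = top_token[c] if tok == "<UNK>" else tok
    return out
```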
@@ -6,11 +6,11 @@
"vocab_path": "./results/disc_vocab.json",
"out_dir": "./results",
"device": "auto",
"timesteps": 400,
"timesteps": 600,
"batch_size": 128,
"seq_len": 128,
"epochs": 8,
"max_batches": 3000,
"epochs": 10,
"max_batches": 4000,
"lambda": 0.5,
"lr": 0.0005,
"seed": 1337,
@@ -23,8 +23,17 @@
"use_condition": true,
"condition_type": "file_id",
"cond_dim": 32,
"use_tanh_eps": true,
"use_tanh_eps": false,
"eps_scale": 1.0,
"model_time_dim": 128,
"model_hidden_dim": 512,
"model_num_layers": 2,
"model_dropout": 0.1,
"model_ff_mult": 2,
"model_pos_dim": 64,
"model_use_pos_embed": true,
"disc_mask_scale": 0.9,
"shuffle_buffer": 256,
"sample_batch_size": 8,
"sample_seq_len": 128
}

@@ -4,6 +4,8 @@
import csv
import gzip
import json
import math
import random
from typing import Dict, Iterable, List, Optional, Tuple, Union

@@ -23,62 +25,173 @@ def iter_rows(path_or_paths: Union[str, List[str]]) -> Iterable[Dict[str, str]]:
yield row

def compute_cont_stats(
def _stream_basic_stats(
path: Union[str, List[str]],
cont_cols: List[str],
max_rows: Optional[int] = None,
) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, bool], Dict[str, int]]:
"""Streaming mean/std (Welford) + min/max + int/precision metadata."""
count = 0
):
"""Streaming stats with mean/M2/M3 + min/max + int/precision metadata."""
count = {c: 0 for c in cont_cols}
mean = {c: 0.0 for c in cont_cols}
m2 = {c: 0.0 for c in cont_cols}
m3 = {c: 0.0 for c in cont_cols}
vmin = {c: float("inf") for c in cont_cols}
vmax = {c: float("-inf") for c in cont_cols}
int_like = {c: True for c in cont_cols}
max_decimals = {c: 0 for c in cont_cols}
all_pos = {c: True for c in cont_cols}

for i, row in enumerate(iter_rows(path)):
count += 1
for c in cont_cols:
raw = row[c]
if raw is None or raw == "":
continue
x = float(raw)
delta = x - mean[c]
mean[c] += delta / count
delta2 = x - mean[c]
m2[c] += delta * delta2
if x <= 0:
all_pos[c] = False
if x < vmin[c]:
vmin[c] = x
if x > vmax[c]:
vmax[c] = x
if int_like[c] and abs(x - round(x)) > 1e-9:
int_like[c] = False
# track decimal places from raw string if possible
if "e" in raw or "E" in raw:
# scientific notation, skip precision inference
continue
if "." in raw:
if "e" not in raw and "E" not in raw and "." in raw:
dec = raw.split(".", 1)[1].rstrip("0")
if len(dec) > max_decimals[c]:
max_decimals[c] = len(dec)

n = count[c] + 1
delta = x - mean[c]
delta_n = delta / n
term1 = delta * delta_n * (n - 1)
m3[c] += term1 * delta_n * (n - 2) - 3 * delta_n * m2[c]
m2[c] += term1
mean[c] += delta_n
count[c] = n

if max_rows is not None and i + 1 >= max_rows:
break

# finalize std/skew
std = {}
skew = {}
for c in cont_cols:
if count > 1:
var = m2[c] / (count - 1)
n = count[c]
if n > 1:
var = m2[c] / (n - 1)
else:
var = 0.0
std[c] = var ** 0.5 if var > 0 else 1.0
# replace infs if column had no valid values
if n > 2 and m2[c] > 0:
skew[c] = (math.sqrt(n) * (m3[c] / n)) / (m2[c] ** 1.5)
else:
skew[c] = 0.0

for c in cont_cols:
if vmin[c] == float("inf"):
vmin[c] = 0.0
if vmax[c] == float("-inf"):
vmax[c] = 0.0
return mean, std, vmin, vmax, int_like, max_decimals

return {
"count": count,
"mean": mean,
"std": std,
"m2": m2,
"m3": m3,
"min": vmin,
"max": vmax,
"int_like": int_like,
"max_decimals": max_decimals,
"skew": skew,
"all_pos": all_pos,
}

def choose_cont_transforms(
path: Union[str, List[str]],
cont_cols: List[str],
max_rows: Optional[int] = None,
skew_threshold: float = 1.5,
range_ratio_threshold: float = 1e3,
):
"""Pick per-column transform (currently log1p or none) based on skew/range."""
stats = _stream_basic_stats(path, cont_cols, max_rows=max_rows)
transforms = {}
for c in cont_cols:
if not stats["all_pos"][c]:
transforms[c] = "none"
continue
skew = abs(stats["skew"][c])
vmin = stats["min"][c]
vmax = stats["max"][c]
ratio = (vmax / vmin) if vmin > 0 else 0.0
if skew >= skew_threshold or ratio >= range_ratio_threshold:
transforms[c] = "log1p"
else:
transforms[c] = "none"
return transforms, stats

def compute_cont_stats(
path: Union[str, List[str]],
cont_cols: List[str],
max_rows: Optional[int] = None,
transforms: Optional[Dict[str, str]] = None,
):
"""Compute stats on (optionally transformed) values. Returns raw + transformed stats."""
# First pass (raw) for metadata and raw mean/std
raw = _stream_basic_stats(path, cont_cols, max_rows=max_rows)

# Optional transform selection
if transforms is None:
transforms = {c: "none" for c in cont_cols}

# Second pass for transformed mean/std
count = {c: 0 for c in cont_cols}
mean = {c: 0.0 for c in cont_cols}
m2 = {c: 0.0 for c in cont_cols}
for i, row in enumerate(iter_rows(path)):
for c in cont_cols:
raw_val = row[c]
if raw_val is None or raw_val == "":
continue
x = float(raw_val)
if transforms.get(c) == "log1p":
if x < 0:
x = 0.0
x = math.log1p(x)
n = count[c] + 1
delta = x - mean[c]
mean[c] += delta / n
delta2 = x - mean[c]
m2[c] += delta * delta2
count[c] = n
if max_rows is not None and i + 1 >= max_rows:
break

std = {}
for c in cont_cols:
if count[c] > 1:
var = m2[c] / (count[c] - 1)
else:
var = 0.0
std[c] = var ** 0.5 if var > 0 else 1.0

return {
"mean": mean,
"std": std,
"raw_mean": raw["mean"],
"raw_std": raw["std"],
"min": raw["min"],
"max": raw["max"],
"int_like": raw["int_like"],
"max_decimals": raw["max_decimals"],
"transform": transforms,
"skew": raw["skew"],
"all_pos": raw["all_pos"],
"max_rows": max_rows,
}
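For reference, this is how the two helpers are wired together in prepare_data.py later in this commit (a usage sketch; `data_paths` and `cont_cols` are assumed to be defined as in that script):

```python
# Pick per-column transforms from a streaming pass, then compute stats on the
# (optionally log1p-transformed) values; raw stats are kept alongside.
transforms, _ = choose_cont_transforms(data_paths, cont_cols, max_rows=50000)
cont_stats = compute_cont_stats(data_paths, cont_cols, max_rows=50000, transforms=transforms)
print(cont_stats["transform"])  # e.g. {"P4_HT_PO": "log1p", ...}
print(cont_stats["mean"]["P4_HT_PO"], cont_stats["raw_mean"]["P4_HT_PO"])  # transformed vs raw mean
```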
def build_vocab(
@@ -130,8 +243,19 @@ def build_disc_stats(
return vocab, top_token

def normalize_cont(x, cont_cols: List[str], mean: Dict[str, float], std: Dict[str, float]):
def normalize_cont(
x,
cont_cols: List[str],
mean: Dict[str, float],
std: Dict[str, float],
transforms: Optional[Dict[str, str]] = None,
):
import torch

if transforms:
for i, c in enumerate(cont_cols):
if transforms.get(c) == "log1p":
x[:, :, i] = torch.log1p(torch.clamp(x[:, :, i], min=0))
mean_t = torch.tensor([mean[c] for c in cont_cols], dtype=x.dtype, device=x.device)
std_t = torch.tensor([std[c] for c in cont_cols], dtype=x.dtype, device=x.device)
return (x - mean_t) / std_t
@@ -148,19 +272,34 @@ def windowed_batches(
seq_len: int,
max_batches: Optional[int] = None,
return_file_id: bool = False,
transforms: Optional[Dict[str, str]] = None,
shuffle_buffer: int = 0,
):
import torch
batch_cont = []
batch_disc = []
batch_file = []
buffer = []
seq_cont = []
seq_disc = []

def flush_seq():
nonlocal seq_cont, seq_disc, batch_cont, batch_disc
def flush_seq(file_id: int):
nonlocal seq_cont, seq_disc, batch_cont, batch_disc, batch_file
if len(seq_cont) == seq_len:
if shuffle_buffer and shuffle_buffer > 0:
buffer.append((list(seq_cont), list(seq_disc), file_id))
if len(buffer) >= shuffle_buffer:
idx = random.randrange(len(buffer))
seq_c, seq_d, seq_f = buffer.pop(idx)
batch_cont.append(seq_c)
batch_disc.append(seq_d)
if return_file_id:
batch_file.append(seq_f)
else:
batch_cont.append(seq_cont)
batch_disc.append(seq_disc)
if return_file_id:
batch_file.append(file_id)
seq_cont = []
seq_disc = []

@@ -173,13 +312,11 @@ def windowed_batches(
seq_cont.append(cont_row)
seq_disc.append(disc_row)
if len(seq_cont) == seq_len:
flush_seq()
if return_file_id:
batch_file.append(file_id)
flush_seq(file_id)
if len(batch_cont) == batch_size:
x_cont = torch.tensor(batch_cont, dtype=torch.float32)
x_disc = torch.tensor(batch_disc, dtype=torch.long)
x_cont = normalize_cont(x_cont, cont_cols, mean, std)
x_cont = normalize_cont(x_cont, cont_cols, mean, std, transforms=transforms)
if return_file_id:
x_file = torch.tensor(batch_file, dtype=torch.long)
yield x_cont, x_disc, x_file
@@ -195,4 +332,29 @@ def windowed_batches(
seq_cont = []
seq_disc = []

# Flush any remaining buffered sequences
if shuffle_buffer and buffer:
random.shuffle(buffer)
for seq_c, seq_d, seq_f in buffer:
batch_cont.append(seq_c)
batch_disc.append(seq_d)
if return_file_id:
batch_file.append(seq_f)
if len(batch_cont) == batch_size:
import torch
x_cont = torch.tensor(batch_cont, dtype=torch.float32)
x_disc = torch.tensor(batch_disc, dtype=torch.long)
x_cont = normalize_cont(x_cont, cont_cols, mean, std, transforms=transforms)
if return_file_id:
x_file = torch.tensor(batch_file, dtype=torch.long)
yield x_cont, x_disc, x_file
else:
yield x_cont, x_disc
batch_cont = []
batch_disc = []
batch_file = []
batches_yielded += 1
if max_batches is not None and batches_yielded >= max_batches:
return

# Drop last partial batch for simplicity
@@ -5,8 +5,9 @@ import argparse
import csv
import gzip
import json
import math
from pathlib import Path
from typing import Dict, Tuple
from typing import Dict, Tuple, List, Optional

def load_json(path: str) -> Dict:
@@ -28,6 +29,8 @@ def parse_args():
parser.add_argument("--stats", default=str(base_dir / "results" / "cont_stats.json"))
parser.add_argument("--vocab", default=str(base_dir / "results" / "disc_vocab.json"))
parser.add_argument("--out", default=str(base_dir / "results" / "eval.json"))
parser.add_argument("--reference", default="", help="Optional reference CSV (train) for richer metrics")
parser.add_argument("--max-rows", type=int, default=20000, help="Max rows to load for reference metrics")
return parser.parse_args()

@@ -55,6 +58,62 @@ def finalize_stats(stats):
return out

def js_divergence(p, q, eps: float = 1e-12) -> float:
p = [max(x, eps) for x in p]
q = [max(x, eps) for x in q]
m = [(pi + qi) / 2.0 for pi, qi in zip(p, q)]
def kl(a, b):
return sum(ai * math.log(ai / bi, 2) for ai, bi in zip(a, b))
return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def ks_statistic(x: List[float], y: List[float]) -> float:
if not x or not y:
return 0.0
x_sorted = sorted(x)
y_sorted = sorted(y)
n = len(x_sorted)
m = len(y_sorted)
i = j = 0
cdf_x = cdf_y = 0.0
d = 0.0
while i < n and j < m:
if x_sorted[i] <= y_sorted[j]:
i += 1
cdf_x = i / n
else:
j += 1
cdf_y = j / m
d = max(d, abs(cdf_x - cdf_y))
return d

def lag1_corr(values: List[float]) -> float:
if len(values) < 3:
return 0.0
x = values[:-1]
y = values[1:]
mean_x = sum(x) / len(x)
mean_y = sum(y) / len(y)
num = sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y))
den_x = sum((xi - mean_x) ** 2 for xi in x)
den_y = sum((yi - mean_y) ** 2 for yi in y)
if den_x <= 0 or den_y <= 0:
return 0.0
return num / math.sqrt(den_x * den_y)
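A quick sanity check of the three helpers above (toy values, not taken from the dataset):

```python
# Illustrative only: small hand-made lists to show what each metric returns.
gen = [0.1, 0.4, 0.4, 0.7]
ref = [0.1, 0.2, 0.5, 0.9]
print(ks_statistic(gen, ref))                            # max gap between the two empirical CDFs
print(js_divergence([0.7, 0.2, 0.1], [0.6, 0.3, 0.1]))   # JSD in bits; 0.0 means identical distributions
print(lag1_corr([1.0, 2.0, 3.0, 4.0]))                   # ~1.0 for a monotone ramp
```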
def resolve_reference_path(path: str) -> Optional[str]:
if not path:
return None
if any(ch in path for ch in ["*", "?", "["]):
base = Path(path).parent
pat = Path(path).name
matches = sorted(base.glob(pat))
return str(matches[0]) if matches else None
return str(path)

def main():
args = parse_args()
base_dir = Path(__file__).resolve().parent
@@ -63,13 +122,18 @@ def main():
args.stats = str((base_dir / args.stats).resolve()) if not Path(args.stats).is_absolute() else args.stats
args.vocab = str((base_dir / args.vocab).resolve()) if not Path(args.vocab).is_absolute() else args.vocab
args.out = str((base_dir / args.out).resolve()) if not Path(args.out).is_absolute() else args.out
if args.reference and not Path(args.reference).is_absolute():
args.reference = str((base_dir / args.reference).resolve())
ref_path = resolve_reference_path(args.reference)
split = load_json(args.split)
time_col = split.get("time_column", "time")
cont_cols = [c for c in split["continuous"] if c != time_col]
disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]

stats_ref = load_json(args.stats)["mean"]
std_ref = load_json(args.stats)["std"]
stats_json = load_json(args.stats)
stats_ref = stats_json.get("raw_mean", stats_json.get("mean"))
std_ref = stats_json.get("raw_std", stats_json.get("std"))
transforms = stats_json.get("transform", {})
vocab = load_json(args.vocab)["vocab"]
vocab_sets = {c: set(vocab[c].keys()) for c in disc_cols}

@@ -89,6 +153,8 @@ def main():
except Exception:
v = 0.0
update_stats(cont_stats, c, v)
if ref_path:
pass
for c in disc_cols:
if row[c] not in vocab_sets[c]:
disc_invalid[c] += 1
@@ -112,6 +178,81 @@ def main():
"discrete_invalid_counts": disc_invalid,
}

# Optional richer metrics using reference data
if ref_path:
ref_cont = {c: [] for c in cont_cols}
ref_disc = {c: {} for c in disc_cols}
gen_cont = {c: [] for c in cont_cols}
gen_disc = {c: {} for c in disc_cols}

with open_csv(args.generated) as f:
reader = csv.DictReader(f)
for row in reader:
if time_col in row:
row.pop(time_col, None)
for c in cont_cols:
try:
gen_cont[c].append(float(row[c]))
except Exception:
gen_cont[c].append(0.0)
for c in disc_cols:
tok = row[c]
gen_disc[c][tok] = gen_disc[c].get(tok, 0) + 1

with open_csv(ref_path) as f:
reader = csv.DictReader(f)
for i, row in enumerate(reader):
if time_col in row:
row.pop(time_col, None)
for c in cont_cols:
try:
ref_cont[c].append(float(row[c]))
except Exception:
ref_cont[c].append(0.0)
for c in disc_cols:
tok = row[c]
ref_disc[c][tok] = ref_disc[c].get(tok, 0) + 1
if args.max_rows and i + 1 >= args.max_rows:
break

# Continuous metrics: KS + quantiles + lag1 correlation
cont_ks = {}
cont_quant = {}
cont_lag1 = {}
for c in cont_cols:
cont_ks[c] = ks_statistic(gen_cont[c], ref_cont[c])
ref_sorted = sorted(ref_cont[c])
gen_sorted = sorted(gen_cont[c])
qs = [0.05, 0.25, 0.5, 0.75, 0.95]
def qval(arr, q):
if not arr:
return 0.0
idx = int(q * (len(arr) - 1))
return arr[idx]
cont_quant[c] = {
"q05_diff": abs(qval(gen_sorted, 0.05) - qval(ref_sorted, 0.05)),
"q25_diff": abs(qval(gen_sorted, 0.25) - qval(ref_sorted, 0.25)),
"q50_diff": abs(qval(gen_sorted, 0.5) - qval(ref_sorted, 0.5)),
"q75_diff": abs(qval(gen_sorted, 0.75) - qval(ref_sorted, 0.75)),
"q95_diff": abs(qval(gen_sorted, 0.95) - qval(ref_sorted, 0.95)),
}
cont_lag1[c] = abs(lag1_corr(gen_cont[c]) - lag1_corr(ref_cont[c]))

# Discrete metrics: JSD over vocab
disc_jsd = {}
for c in disc_cols:
vocab_vals = list(vocab_sets[c])
gen_total = sum(gen_disc[c].values()) or 1
ref_total = sum(ref_disc[c].values()) or 1
p = [gen_disc[c].get(v, 0) / gen_total for v in vocab_vals]
q = [ref_disc[c].get(v, 0) / ref_total for v in vocab_vals]
disc_jsd[c] = js_divergence(p, q)

report["continuous_ks"] = cont_ks
report["continuous_quantile_diff"] = cont_quant
report["continuous_lag1_diff"] = cont_lag1
report["discrete_jsd"] = disc_jsd

with open(args.out, "w", encoding="utf-8") as f:
json.dump(report, f, indent=2)

@@ -111,6 +111,7 @@ def main():
vmax = stats.get("max", {})
int_like = stats.get("int_like", {})
max_decimals = stats.get("max_decimals", {})
transforms = stats.get("transform", {})

vocab_json = json.load(open(args.vocab_path, "r", encoding="utf-8"))
vocab = vocab_json["vocab"]
@@ -141,6 +142,13 @@ def main():
model = HybridDiffusionModel(
cont_dim=len(cont_cols),
disc_vocab_sizes=vocab_sizes,
time_dim=int(cfg.get("model_time_dim", 64)),
hidden_dim=int(cfg.get("model_hidden_dim", 256)),
num_layers=int(cfg.get("model_num_layers", 1)),
dropout=float(cfg.get("model_dropout", 0.0)),
ff_mult=int(cfg.get("model_ff_mult", 2)),
pos_dim=int(cfg.get("model_pos_dim", 64)),
use_pos_embed=bool(cfg.get("model_use_pos_embed", True)),
cond_vocab_size=cond_vocab_size if use_condition else 0,
cond_dim=int(cfg.get("cond_dim", 32)),
use_tanh_eps=bool(cfg.get("use_tanh_eps", False)),
@@ -220,6 +228,9 @@ def main():
mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
x_cont = x_cont * std_vec + mean_vec
for i, c in enumerate(cont_cols):
if transforms.get(c) == "log1p":
x_cont[:, :, i] = torch.expm1(x_cont[:, :, i])
# clamp to observed min/max per feature
if vmin and vmax:
for i, c in enumerate(cont_cols):

@@ -35,11 +35,14 @@ def q_sample_discrete(
t: torch.Tensor,
mask_tokens: torch.Tensor,
max_t: int,
mask_scale: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Randomly mask discrete tokens with a cosine schedule over t."""
bsz = x0.size(0)
# cosine schedule: p(0)=0, p(max_t)=1
p = 0.5 * (1.0 - torch.cos(math.pi * t.float() / float(max_t)))
if mask_scale != 1.0:
p = torch.clamp(p * mask_scale, 0.0, 1.0)
p = p.view(bsz, 1, 1)
mask = torch.rand_like(x0.float()) < p
x_masked = x0.clone()
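A small worked example of the masking probability above (not part of the diff), using the `timesteps`/`disc_mask_scale` values from the updated config:

```python
# p(t) = 0.5 * (1 - cos(pi * t / max_t)), optionally capped by mask_scale.
import math

max_t, mask_scale = 600, 0.9
for t in (0, 150, 300, 450, 600):
    p = 0.5 * (1.0 - math.cos(math.pi * t / max_t))
    print(t, round(min(p * mask_scale, 1.0), 3))
# t=0 -> 0.0, t=300 -> 0.45, t=600 -> 0.9: with mask_scale=0.9, at most 90% of
# discrete tokens are masked even at the final diffusion step.
```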
@@ -70,6 +73,11 @@ class HybridDiffusionModel(nn.Module):
disc_vocab_sizes: List[int],
time_dim: int = 64,
hidden_dim: int = 256,
num_layers: int = 1,
dropout: float = 0.0,
ff_mult: int = 2,
pos_dim: int = 64,
use_pos_embed: bool = True,
cond_vocab_size: int = 0,
cond_dim: int = 32,
use_tanh_eps: bool = False,
@@ -82,6 +90,8 @@ class HybridDiffusionModel(nn.Module):
self.time_embed = SinusoidalTimeEmbedding(time_dim)
self.use_tanh_eps = use_tanh_eps
self.eps_scale = eps_scale
self.pos_dim = pos_dim
self.use_pos_embed = use_pos_embed

self.cond_vocab_size = cond_vocab_size
self.cond_dim = cond_dim
@@ -96,9 +106,22 @@ class HybridDiffusionModel(nn.Module):
disc_embed_dim = sum(e.embedding_dim for e in self.disc_embeds)

self.cont_proj = nn.Linear(cont_dim, cont_dim)
in_dim = cont_dim + disc_embed_dim + time_dim + (cond_dim if self.cond_embed is not None else 0)
pos_dim = pos_dim if use_pos_embed else 0
in_dim = cont_dim + disc_embed_dim + time_dim + pos_dim + (cond_dim if self.cond_embed is not None else 0)
self.in_proj = nn.Linear(in_dim, hidden_dim)
self.backbone = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
self.backbone = nn.GRU(
hidden_dim,
hidden_dim,
num_layers=num_layers,
dropout=dropout if num_layers > 1 else 0.0,
batch_first=True,
)
self.post_norm = nn.LayerNorm(hidden_dim)
self.post_ff = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim * ff_mult),
nn.GELU(),
nn.Linear(hidden_dim * ff_mult, hidden_dim),
)

self.cont_head = nn.Linear(hidden_dim, cont_dim)
self.disc_heads = nn.ModuleList([
@@ -110,6 +133,9 @@ class HybridDiffusionModel(nn.Module):
"""x_cont: (B,T,Cc), x_disc: (B,T,Cd) with integer tokens."""
time_emb = self.time_embed(t)
time_emb = time_emb.unsqueeze(1).expand(-1, x_cont.size(1), -1)
pos_emb = None
if self.use_pos_embed and self.pos_dim > 0:
pos_emb = self._positional_encoding(x_cont.size(1), self.pos_dim, x_cont.device)

disc_embs = []
for i, emb in enumerate(self.disc_embeds):
@@ -124,15 +150,28 @@ class HybridDiffusionModel(nn.Module):

cont_feat = self.cont_proj(x_cont)
parts = [cont_feat, disc_feat, time_emb]
if pos_emb is not None:
parts.append(pos_emb.unsqueeze(0).expand(x_cont.size(0), -1, -1))
if cond_feat is not None:
parts.append(cond_feat)
feat = torch.cat(parts, dim=-1)
feat = self.in_proj(feat)

out, _ = self.backbone(feat)
out = self.post_norm(out)
out = out + self.post_ff(out)

eps_pred = self.cont_head(out)
if self.use_tanh_eps:
eps_pred = torch.tanh(eps_pred) * self.eps_scale
logits = [head(out) for head in self.disc_heads]
return eps_pred, logits

@staticmethod
def _positional_encoding(seq_len: int, dim: int, device: torch.device) -> torch.Tensor:
pos = torch.arange(seq_len, device=device).float().unsqueeze(1)
div = torch.exp(torch.arange(0, dim, 2, device=device).float() * (-math.log(10000.0) / dim))
pe = torch.zeros(seq_len, dim, device=device)
pe[:, 0::2] = torch.sin(pos * div)
pe[:, 1::2] = torch.cos(pos * div)
return pe
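A hedged construction/forward-pass sketch for the updated model, based on the constructor arguments and the `model(x_cont, x_disc, t, cond)` call shown elsewhere in this commit; the dimensions, vocab sizes, import path, and the shape of `cond` are illustrative assumptions, not values from the repo:

```python
import torch
from model import HybridDiffusionModel  # module path assumed

model = HybridDiffusionModel(
    cont_dim=4,
    disc_vocab_sizes=[10, 7],
    time_dim=128,
    hidden_dim=512,
    num_layers=2,
    dropout=0.1,
    ff_mult=2,
    pos_dim=64,
    use_pos_embed=True,
    cond_vocab_size=3,   # e.g. number of training files when conditioning on file id
    cond_dim=32,
    use_tanh_eps=False,
)
x_cont = torch.randn(8, 128, 4)             # (B, T, Cc) noised continuous features
x_disc = torch.randint(0, 7, (8, 128, 2))   # (B, T, Cd) integer tokens (partially masked)
t = torch.randint(0, 600, (8,))             # diffusion step per sequence
cond = torch.randint(0, 3, (8,))            # file-id condition, assumed shape (B,)
eps_pred, logits = model(x_cont, x_disc, t, cond)
# eps_pred: (B, T, Cc) noise prediction; logits: one (B, T, vocab) tensor per discrete column
```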
@@ -5,7 +5,7 @@ import json
from pathlib import Path
from typing import Optional

from data_utils import compute_cont_stats, build_disc_stats, load_split
from data_utils import compute_cont_stats, build_disc_stats, load_split, choose_cont_transforms
from platform_utils import safe_path, ensure_dir

BASE_DIR = Path(__file__).resolve().parent
@@ -27,20 +27,25 @@ def main(max_rows: Optional[int] = None):
raise SystemExit("no train files found under %s" % str(DATA_GLOB))
data_paths = [safe_path(p) for p in data_paths]

mean, std, vmin, vmax, int_like, max_decimals = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows)
transforms, _ = choose_cont_transforms(data_paths, cont_cols, max_rows=max_rows)
cont_stats = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows, transforms=transforms)
vocab, top_token = build_disc_stats(data_paths, disc_cols, max_rows=max_rows)

ensure_dir(OUT_STATS.parent)
with open(safe_path(OUT_STATS), "w", encoding="utf-8") as f:
json.dump(
{
"mean": mean,
"std": std,
"min": vmin,
"max": vmax,
"int_like": int_like,
"max_decimals": max_decimals,
"max_rows": max_rows,
"mean": cont_stats["mean"],
"std": cont_stats["std"],
"raw_mean": cont_stats["raw_mean"],
"raw_std": cont_stats["raw_std"],
"min": cont_stats["min"],
"max": cont_stats["max"],
"int_like": cont_stats["int_like"],
"max_decimals": cont_stats["max_decimals"],
"transform": cont_stats["transform"],
"skew": cont_stats["skew"],
"max_rows": cont_stats["max_rows"],
},
f,
indent=2,

@@ -45,7 +45,7 @@
"P3_PIT01": 668.9722350000003,
"P4_HT_FD": -0.00010012580000000082,
"P4_HT_LD": 35.41945000099953,
"P4_HT_PO": 35.4085699912002,
"P4_HT_PO": 2.6391372939040414,
"P4_LD": 365.3833745803986,
"P4_ST_FD": -6.5205999999999635e-06,
"P4_ST_GOV": 17801.81294499996,
@@ -100,7 +100,7 @@
"P3_PIT01": 1168.1071264424027,
"P4_HT_FD": 0.002032582380617592,
"P4_HT_LD": 33.212361169253235,
"P4_HT_PO": 31.187825914515162,
"P4_HT_PO": 1.7636196192459512,
"P4_LD": 59.736616589045646,
"P4_ST_FD": 0.0016428787127432496,
"P4_ST_GOV": 1740.5997458128215,
@@ -109,6 +109,116 @@
"P4_ST_PT01": 22.459962818146252,
"P4_ST_TT01": 24.745939350221477
},
"raw_mean": {
"P1_B2004": 0.08649086820000026,
"P1_B2016": 1.376161456000001,
"P1_B3004": 396.1861596906018,
"P1_B3005": 1037.372384413793,
"P1_B4002": 32.564872940799994,
"P1_B4005": 65.98190757240047,
"P1_B400B": 1925.0391570245934,
"P1_B4022": 36.28908066800001,
"P1_FCV02Z": 21.744261118400036,
"P1_FCV03D": 57.36123274140044,
"P1_FCV03Z": 58.05084519640002,
"P1_FT01": 184.18615112319728,
"P1_FT01Z": 851.8781750705965,
"P1_FT02": 1255.8572173544069,
"P1_FT02Z": 1925.0210755194114,
"P1_FT03": 269.37285885780574,
"P1_FT03Z": 1037.366172230601,
"P1_LCV01D": 11.228849048599963,
"P1_LCV01Z": 10.991610181600016,
"P1_LIT01": 396.8845311109994,
"P1_PCV01D": 53.80101618419986,
"P1_PCV01Z": 54.646640287199595,
"P1_PCV02Z": 12.017773542800072,
"P1_PIT01": 1.3692859488000075,
"P1_PIT02": 0.44459071260000227,
"P1_TIT01": 35.64255813999988,
"P1_TIT02": 36.44807823060023,
"P2_24Vdc": 28.0280019013999,
"P2_CO_rpm": 54105.64434999997,
"P2_HILout": 712.0588667425922,
"P2_MSD": 763.19324,
"P2_SIT01": 778.7769850000013,
"P2_SIT02": 778.7778935471981,
"P2_VT01": 11.914949448200044,
"P2_VXT02": -3.5267871940000175,
"P2_VXT03": -1.5520904921999914,
"P2_VYT02": 3.796112737600002,
"P2_VYT03": 6.121691697000018,
"P3_FIT01": 1168.2528800000014,
"P3_LCP01D": 4675.465239999989,
"P3_LCV01D": 7445.208720000017,
"P3_LIT01": 13728.982314999852,
"P3_PIT01": 668.9722350000003,
"P4_HT_FD": -0.00010012580000000082,
"P4_HT_LD": 35.41945000099953,
"P4_HT_PO": 35.4085699912002,
"P4_LD": 365.3833745803986,
"P4_ST_FD": -6.5205999999999635e-06,
"P4_ST_GOV": 17801.81294499996,
"P4_ST_LD": 329.83259218199964,
"P4_ST_PO": 330.1079461497967,
"P4_ST_PT01": 10047.679605000127,
"P4_ST_TT01": 27606.860070000155
},
"raw_std": {
"P1_B2004": 0.024492489898690458,
"P1_B2016": 0.12949272564759745,
"P1_B3004": 10.16264800653289,
"P1_B3005": 70.85697659109,
"P1_B4002": 0.7578213113008355,
"P1_B4005": 41.80065314991797,
"P1_B400B": 1176.6445547448632,
"P1_B4022": 0.8221115066487089,
"P1_FCV02Z": 39.11843197764177,
"P1_FCV03D": 7.889507447726625,
"P1_FCV03Z": 8.046068905945717,
"P1_FT01": 30.80117031882856,
"P1_FT01Z": 91.2786865433318,
"P1_FT02": 879.7163277334494,
"P1_FT02Z": 1176.6699531305117,
"P1_FT03": 38.18015841964941,
"P1_FT03Z": 70.73100774436428,
"P1_LCV01D": 3.3355655415557597,
"P1_LCV01Z": 3.386332233773545,
"P1_LIT01": 10.578714760104122,
"P1_PCV01D": 19.61567943613885,
"P1_PCV01Z": 19.778754467302086,
"P1_PCV02Z": 0.0048047978931599995,
"P1_PIT01": 0.0776614954053113,
"P1_PIT02": 0.44823231815652304,
"P1_TIT01": 0.5986678527528814,
"P1_TIT02": 1.1892341204521049,
"P2_24Vdc": 0.003208842504097816,
"P2_CO_rpm": 20.575477821507334,
"P2_HILout": 8.178853379908608,
"P2_MSD": 1.0,
"P2_SIT01": 3.894535775667256,
"P2_SIT02": 3.8824770788579395,
"P2_VT01": 0.06812990916670247,
"P2_VXT02": 0.43104157117568803,
"P2_VXT03": 0.26894251958139775,
"P2_VYT02": 0.4610907883207586,
"P2_VYT03": 0.30596429385075474,
"P3_FIT01": 1787.2987693141868,
"P3_LCP01D": 5145.4094261812725,
"P3_LCV01D": 6785.602781765096,
"P3_LIT01": 4060.915441872745,
"P3_PIT01": 1168.1071264424027,
"P4_HT_FD": 0.002032582380617592,
"P4_HT_LD": 33.21236116925323,
"P4_HT_PO": 31.18782591451516,
"P4_LD": 59.736616589045646,
"P4_ST_FD": 0.0016428787127432496,
"P4_ST_GOV": 1740.5997458128213,
"P4_ST_LD": 35.86633288900077,
"P4_ST_PO": 32.375012735256696,
"P4_ST_PT01": 22.45996281814625,
"P4_ST_TT01": 24.745939350221487
},
"min": {
"P1_B2004": 0.03051,
"P1_B2016": 0.94729,
@@ -329,5 +439,115 @@
"P4_ST_PT01": 2,
"P4_ST_TT01": 2
},
"transform": {
"P1_B2004": "none",
"P1_B2016": "none",
"P1_B3004": "none",
"P1_B3005": "none",
"P1_B4002": "none",
"P1_B4005": "none",
"P1_B400B": "none",
"P1_B4022": "none",
"P1_FCV02Z": "none",
"P1_FCV03D": "none",
"P1_FCV03Z": "none",
"P1_FT01": "none",
"P1_FT01Z": "none",
"P1_FT02": "none",
"P1_FT02Z": "none",
"P1_FT03": "none",
"P1_FT03Z": "none",
"P1_LCV01D": "none",
"P1_LCV01Z": "none",
"P1_LIT01": "none",
"P1_PCV01D": "none",
"P1_PCV01Z": "none",
"P1_PCV02Z": "none",
"P1_PIT01": "none",
"P1_PIT02": "none",
"P1_TIT01": "none",
"P1_TIT02": "none",
"P2_24Vdc": "none",
"P2_CO_rpm": "none",
"P2_HILout": "none",
"P2_MSD": "none",
"P2_SIT01": "none",
"P2_SIT02": "none",
"P2_VT01": "none",
"P2_VXT02": "none",
"P2_VXT03": "none",
"P2_VYT02": "none",
"P2_VYT03": "none",
"P3_FIT01": "none",
"P3_LCP01D": "none",
"P3_LCV01D": "none",
"P3_LIT01": "none",
"P3_PIT01": "none",
"P4_HT_FD": "none",
"P4_HT_LD": "none",
"P4_HT_PO": "log1p",
"P4_LD": "none",
"P4_ST_FD": "none",
"P4_ST_GOV": "none",
"P4_ST_LD": "none",
"P4_ST_PO": "none",
"P4_ST_PT01": "none",
"P4_ST_TT01": "none"
},
"skew": {
"P1_B2004": -2.876938578031295e-05,
"P1_B2016": 2.014565216651284e-06,
"P1_B3004": 6.625985939357487e-06,
"P1_B3005": -9.917489652810193e-06,
"P1_B4002": 1.4641465884161855e-05,
"P1_B4005": -1.2370279269006856e-05,
"P1_B400B": -1.4116897198097317e-05,
"P1_B4022": 1.1162291352215598e-05,
"P1_FCV02Z": 2.532521501167817e-05,
"P1_FCV03D": 4.2517931711793e-06,
"P1_FCV03Z": 4.301856332440012e-06,
"P1_FT01": -1.3345735264961829e-05,
"P1_FT01Z": -4.2554413198354234e-05,
"P1_FT02": -1.0289230789249066e-05,
"P1_FT02Z": -1.4116856909216661e-05,
"P1_FT03": -4.341090521713463e-06,
"P1_FT03Z": -9.964308983887345e-06,
"P1_LCV01D": 2.541312481372867e-06,
"P1_LCV01Z": 2.5806433622267527e-06,
"P1_LIT01": 7.716120912717401e-06,
"P1_PCV01D": 2.113459306618771e-05,
"P1_PCV01Z": 2.0632832525407433e-05,
"P1_PCV02Z": 4.2639616636720384e-08,
"P1_PIT01": 2.079887220863843e-05,
"P1_PIT02": 5.003507344873546e-05,
"P1_TIT01": 9.553657000925262e-06,
"P1_TIT02": 2.1170357380515215e-05,
"P2_24Vdc": 2.735770838906968e-07,
"P2_CO_rpm": -8.124011608472296e-06,
"P2_HILout": -4.086282393330704e-06,
"P2_MSD": 0.0,
"P2_SIT01": -7.418240348817199e-06,
"P2_SIT02": -7.457826456660247e-06,
"P2_VT01": 1.247484205979928e-07,
"P2_VXT02": 6.53499778855353e-07,
"P2_VXT03": 5.32656056809399e-06,
"P2_VYT02": 9.483158480759724e-07,
"P2_VYT03": 2.128755351566922e-06,
"P3_FIT01": 2.2828575320599336e-05,
"P3_LCP01D": 1.3040993552131866e-05,
"P3_LCV01D": 3.781324885318626e-07,
"P3_LIT01": -7.824733758742217e-06,
"P3_PIT01": 3.210613447428708e-05,
"P4_HT_FD": 9.197840236384403e-05,
"P4_HT_LD": -2.4568845167931336e-08,
"P4_HT_PO": 3.997415489949367e-07,
"P4_LD": -6.253448074273654e-07,
"P4_ST_FD": 2.3472181460829935e-07,
"P4_ST_GOV": 2.494268873407866e-06,
"P4_ST_LD": 1.6692758818969547e-06,
"P4_ST_PO": 2.45129838870492e-06,
"P4_ST_PT01": 1.7637202837434092e-05,
"P4_ST_TT01": -1.9485876142550594e-05
},
"max_rows": 50000
}
File diff suppressed because it is too large
@@ -48,6 +48,8 @@ def main():
seq_len = cfg.get("sample_seq_len", cfg.get("seq_len", 64))
batch_size = cfg.get("sample_batch_size", cfg.get("batch_size", 2))
clip_k = cfg.get("clip_k", 5.0)
data_glob = cfg.get("data_glob", "")
data_path = cfg.get("data_path", "")
run([sys.executable, str(base_dir / "prepare_data.py")])
run([sys.executable, str(base_dir / "train.py"), "--config", args.config, "--device", args.device])
run(
@@ -70,6 +72,10 @@ def main():
"--use-ema",
]
)
ref = data_glob if data_glob else data_path
if ref:
run([sys.executable, str(base_dir / "evaluate_generated.py"), "--reference", str(ref)])
else:
run([sys.executable, str(base_dir / "evaluate_generated.py")])
run([sys.executable, str(base_dir / "plot_loss.py")])

@@ -47,6 +47,13 @@ def main():
cond_dim = int(cfg.get("cond_dim", 32))
use_tanh_eps = bool(cfg.get("use_tanh_eps", False))
eps_scale = float(cfg.get("eps_scale", 1.0))
model_time_dim = int(cfg.get("model_time_dim", 64))
model_hidden_dim = int(cfg.get("model_hidden_dim", 256))
model_num_layers = int(cfg.get("model_num_layers", 1))
model_dropout = float(cfg.get("model_dropout", 0.0))
model_ff_mult = int(cfg.get("model_ff_mult", 2))
model_pos_dim = int(cfg.get("model_pos_dim", 64))
model_use_pos = bool(cfg.get("model_use_pos_embed", True))

split = load_split(str(SPLIT_PATH))
time_col = split.get("time_column", "time")
@@ -67,6 +74,13 @@ def main():
model = HybridDiffusionModel(
cont_dim=len(cont_cols),
disc_vocab_sizes=vocab_sizes,
time_dim=model_time_dim,
hidden_dim=model_hidden_dim,
num_layers=model_num_layers,
dropout=model_dropout,
ff_mult=model_ff_mult,
pos_dim=model_pos_dim,
use_pos_embed=model_use_pos,
cond_vocab_size=cond_vocab_size,
cond_dim=cond_dim,
use_tanh_eps=use_tanh_eps,

@@ -49,8 +49,17 @@ DEFAULTS = {
"use_condition": True,
"condition_type": "file_id",
"cond_dim": 32,
"use_tanh_eps": True,
"use_tanh_eps": False,
"eps_scale": 1.0,
"model_time_dim": 128,
"model_hidden_dim": 512,
"model_num_layers": 2,
"model_dropout": 0.1,
"model_ff_mult": 2,
"model_pos_dim": 64,
"model_use_pos_embed": True,
"disc_mask_scale": 0.9,
"shuffle_buffer": 256,
}

@@ -144,6 +153,7 @@ def main():
stats = load_json(config["stats_path"])
mean = stats["mean"]
std = stats["std"]
transforms = stats.get("transform", {})

vocab = load_json(config["vocab_path"])["vocab"]
vocab_sizes = [len(vocab[c]) for c in disc_cols]
@@ -164,6 +174,13 @@ def main():
model = HybridDiffusionModel(
cont_dim=len(cont_cols),
disc_vocab_sizes=vocab_sizes,
time_dim=int(config.get("model_time_dim", 64)),
hidden_dim=int(config.get("model_hidden_dim", 256)),
num_layers=int(config.get("model_num_layers", 1)),
dropout=float(config.get("model_dropout", 0.0)),
ff_mult=int(config.get("model_ff_mult", 2)),
pos_dim=int(config.get("model_pos_dim", 64)),
use_pos_embed=bool(config.get("model_use_pos_embed", True)),
cond_vocab_size=cond_vocab_size,
cond_dim=int(config.get("cond_dim", 32)),
use_tanh_eps=bool(config.get("use_tanh_eps", False)),
@@ -198,6 +215,8 @@ def main():
seq_len=int(config["seq_len"]),
max_batches=int(config["max_batches"]),
return_file_id=use_condition,
transforms=transforms,
shuffle_buffer=int(config.get("shuffle_buffer", 0)),
)
):
if use_condition:
@@ -215,7 +234,13 @@ def main():
x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)

mask_tokens = torch.tensor(vocab_sizes, device=device)
x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, int(config["timesteps"]))
x_disc_t, mask = q_sample_discrete(
x_disc,
t,
mask_tokens,
int(config["timesteps"]),
mask_scale=float(config.get("disc_mask_scale", 1.0)),
)

eps_pred, logits = model(x_cont_t, x_disc_t, t, cond)