Commit by MZ YANG on 2026-01-23 23:40:44 +08:00
12 changed files with 1404 additions and 260 deletions

View File

@@ -67,6 +67,7 @@ python example/run_pipeline.py --device auto
- Optional conditioning by file id (`train*.csv.gz`) is enabled by default for multi-file training.
- Continuous head can be bounded with `tanh` via `use_tanh_eps` in config.
- Export now clamps continuous features to training min/max and preserves integer/decimal precision.
- Continuous features may be log1p-transformed automatically for heavy-tailed columns (see cont_stats.json and the round-trip sketch after this list).
- `<UNK>` tokens are replaced by the most frequent token for each discrete column at export.
- The script only samples the first 5000 rows to stay fast.
- `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it.
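A minimal round-trip sketch of the log1p path (editorial illustration, not repo code): values are log1p-transformed before z-scoring, and export reverses both steps with expm1. The mean/std below are the P4_HT_PO entries from cont_stats.json; the raw reading is a made-up example.

```python
import math

raw = 35.4085699912002                                # a raw P4_HT_PO-style reading (example value)
mean, std = 2.6391372939040414, 1.7636196192459512    # log-space stats from cont_stats.json

z = (math.log1p(raw) - mean) / std                    # value the model trains on
restored = math.expm1(z * std + mean)                 # value export writes back out
assert abs(restored - raw) < 1e-6
```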

View File

@@ -6,11 +6,11 @@
"vocab_path": "./results/disc_vocab.json", "vocab_path": "./results/disc_vocab.json",
"out_dir": "./results", "out_dir": "./results",
"device": "auto", "device": "auto",
"timesteps": 400, "timesteps": 600,
"batch_size": 128, "batch_size": 128,
"seq_len": 128, "seq_len": 128,
"epochs": 8, "epochs": 10,
"max_batches": 3000, "max_batches": 4000,
"lambda": 0.5, "lambda": 0.5,
"lr": 0.0005, "lr": 0.0005,
"seed": 1337, "seed": 1337,
@@ -23,8 +23,17 @@
"use_condition": true, "use_condition": true,
"condition_type": "file_id", "condition_type": "file_id",
"cond_dim": 32, "cond_dim": 32,
"use_tanh_eps": true, "use_tanh_eps": false,
"eps_scale": 1.0, "eps_scale": 1.0,
"model_time_dim": 128,
"model_hidden_dim": 512,
"model_num_layers": 2,
"model_dropout": 0.1,
"model_ff_mult": 2,
"model_pos_dim": 64,
"model_use_pos_embed": true,
"disc_mask_scale": 0.9,
"shuffle_buffer": 256,
"sample_batch_size": 8, "sample_batch_size": 8,
"sample_seq_len": 128 "sample_seq_len": 128
} }

View File

@@ -4,6 +4,8 @@
import csv
import gzip
import json
import math
import random
from typing import Dict, Iterable, List, Optional, Tuple, Union
@@ -23,62 +25,173 @@ def iter_rows(path_or_paths: Union[str, List[str]]) -> Iterable[Dict[str, str]]:
yield row
- def compute_cont_stats(
+ def _stream_basic_stats(
path: Union[str, List[str]],
cont_cols: List[str],
max_rows: Optional[int] = None,
- ) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, bool], Dict[str, int]]:
+ ):
- """Streaming mean/std (Welford) + min/max + int/precision metadata."""
+ """Streaming stats with mean/M2/M3 + min/max + int/precision metadata."""
- count = 0
+ count = {c: 0 for c in cont_cols}
mean = {c: 0.0 for c in cont_cols}
m2 = {c: 0.0 for c in cont_cols}
m3 = {c: 0.0 for c in cont_cols}
vmin = {c: float("inf") for c in cont_cols}
vmax = {c: float("-inf") for c in cont_cols}
int_like = {c: True for c in cont_cols}
max_decimals = {c: 0 for c in cont_cols}
all_pos = {c: True for c in cont_cols}
for i, row in enumerate(iter_rows(path)):
- count += 1
for c in cont_cols:
raw = row[c]
if raw is None or raw == "":
continue
x = float(raw)
- delta = x - mean[c]
- mean[c] += delta / count
- delta2 = x - mean[c]
- m2[c] += delta * delta2
+ if x <= 0:
+ all_pos[c] = False
if x < vmin[c]:
vmin[c] = x
if x > vmax[c]:
vmax[c] = x
if int_like[c] and abs(x - round(x)) > 1e-9:
int_like[c] = False
- # track decimal places from raw string if possible
- if "e" in raw or "E" in raw:
- # scientific notation, skip precision inference
- continue
- if "." in raw:
+ if "e" not in raw and "E" not in raw and "." in raw:
dec = raw.split(".", 1)[1].rstrip("0")
if len(dec) > max_decimals[c]:
max_decimals[c] = len(dec)
n = count[c] + 1
delta = x - mean[c]
delta_n = delta / n
term1 = delta * delta_n * (n - 1)
m3[c] += term1 * delta_n * (n - 2) - 3 * delta_n * m2[c]
m2[c] += term1
mean[c] += delta_n
count[c] = n
if max_rows is not None and i + 1 >= max_rows:
break
# finalize std/skew
std = {}
skew = {}
for c in cont_cols:
- if count > 1:
- var = m2[c] / (count - 1)
+ n = count[c]
+ if n > 1:
+ var = m2[c] / (n - 1)
else:
var = 0.0
std[c] = var ** 0.5 if var > 0 else 1.0
- # replace infs if column had no valid values
+ if n > 2 and m2[c] > 0:
skew[c] = (math.sqrt(n) * (m3[c] / n)) / (m2[c] ** 1.5)
else:
skew[c] = 0.0
for c in cont_cols:
if vmin[c] == float("inf"):
vmin[c] = 0.0
if vmax[c] == float("-inf"):
vmax[c] = 0.0
- return mean, std, vmin, vmax, int_like, max_decimals
return {
"count": count,
"mean": mean,
"std": std,
"m2": m2,
"m3": m3,
"min": vmin,
"max": vmax,
"int_like": int_like,
"max_decimals": max_decimals,
"skew": skew,
"all_pos": all_pos,
}
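As a standalone check (editor's sketch, not part of the commit), the one-pass mean/M2/M3 update used in `_stream_basic_stats` can be verified against direct central sums on a small list:

```python
xs = [1.0, 4.0, 2.0, 8.0, 5.0]

count, mean, m2, m3 = 0, 0.0, 0.0, 0.0
for x in xs:
    n = count + 1
    delta = x - mean
    delta_n = delta / n
    term1 = delta * delta_n * (n - 1)
    m3 += term1 * delta_n * (n - 2) - 3 * delta_n * m2   # uses the old m2, same order as above
    m2 += term1
    mean += delta_n
    count = n

mu = sum(xs) / len(xs)
assert abs(mean - mu) < 1e-9
assert abs(m2 - sum((v - mu) ** 2 for v in xs)) < 1e-9   # sum of squared deviations
assert abs(m3 - sum((v - mu) ** 3 for v in xs)) < 1e-9   # sum of cubed deviations
```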
def choose_cont_transforms(
path: Union[str, List[str]],
cont_cols: List[str],
max_rows: Optional[int] = None,
skew_threshold: float = 1.5,
range_ratio_threshold: float = 1e3,
):
"""Pick per-column transform (currently log1p or none) based on skew/range."""
stats = _stream_basic_stats(path, cont_cols, max_rows=max_rows)
transforms = {}
for c in cont_cols:
if not stats["all_pos"][c]:
transforms[c] = "none"
continue
skew = abs(stats["skew"][c])
vmin = stats["min"][c]
vmax = stats["max"][c]
ratio = (vmax / vmin) if vmin > 0 else 0.0
if skew >= skew_threshold or ratio >= range_ratio_threshold:
transforms[c] = "log1p"
else:
transforms[c] = "none"
return transforms, stats
def compute_cont_stats(
path: Union[str, List[str]],
cont_cols: List[str],
max_rows: Optional[int] = None,
transforms: Optional[Dict[str, str]] = None,
):
"""Compute stats on (optionally transformed) values. Returns raw + transformed stats."""
# First pass (raw) for metadata and raw mean/std
raw = _stream_basic_stats(path, cont_cols, max_rows=max_rows)
# Optional transform selection
if transforms is None:
transforms = {c: "none" for c in cont_cols}
# Second pass for transformed mean/std
count = {c: 0 for c in cont_cols}
mean = {c: 0.0 for c in cont_cols}
m2 = {c: 0.0 for c in cont_cols}
for i, row in enumerate(iter_rows(path)):
for c in cont_cols:
raw_val = row[c]
if raw_val is None or raw_val == "":
continue
x = float(raw_val)
if transforms.get(c) == "log1p":
if x < 0:
x = 0.0
x = math.log1p(x)
n = count[c] + 1
delta = x - mean[c]
mean[c] += delta / n
delta2 = x - mean[c]
m2[c] += delta * delta2
count[c] = n
if max_rows is not None and i + 1 >= max_rows:
break
std = {}
for c in cont_cols:
if count[c] > 1:
var = m2[c] / (count[c] - 1)
else:
var = 0.0
std[c] = var ** 0.5 if var > 0 else 1.0
return {
"mean": mean,
"std": std,
"raw_mean": raw["mean"],
"raw_std": raw["std"],
"min": raw["min"],
"max": raw["max"],
"int_like": raw["int_like"],
"max_decimals": raw["max_decimals"],
"transform": transforms,
"skew": raw["skew"],
"all_pos": raw["all_pos"],
"max_rows": max_rows,
}
def build_vocab(
@@ -130,8 +243,19 @@ def build_disc_stats(
return vocab, top_token
- def normalize_cont(x, cont_cols: List[str], mean: Dict[str, float], std: Dict[str, float]):
+ def normalize_cont(
x,
cont_cols: List[str],
mean: Dict[str, float],
std: Dict[str, float],
transforms: Optional[Dict[str, str]] = None,
):
import torch
if transforms:
for i, c in enumerate(cont_cols):
if transforms.get(c) == "log1p":
x[:, :, i] = torch.log1p(torch.clamp(x[:, :, i], min=0))
mean_t = torch.tensor([mean[c] for c in cont_cols], dtype=x.dtype, device=x.device)
std_t = torch.tensor([std[c] for c in cont_cols], dtype=x.dtype, device=x.device)
return (x - mean_t) / std_t
@@ -148,19 +272,34 @@ def windowed_batches(
seq_len: int,
max_batches: Optional[int] = None,
return_file_id: bool = False,
transforms: Optional[Dict[str, str]] = None,
shuffle_buffer: int = 0,
):
import torch
batch_cont = []
batch_disc = []
batch_file = []
buffer = []
seq_cont = []
seq_disc = []
- def flush_seq():
+ def flush_seq(file_id: int):
- nonlocal seq_cont, seq_disc, batch_cont, batch_disc
+ nonlocal seq_cont, seq_disc, batch_cont, batch_disc, batch_file
if len(seq_cont) == seq_len:
- batch_cont.append(seq_cont)
- batch_disc.append(seq_disc)
+ if shuffle_buffer and shuffle_buffer > 0:
+ buffer.append((list(seq_cont), list(seq_disc), file_id))
if len(buffer) >= shuffle_buffer:
idx = random.randrange(len(buffer))
seq_c, seq_d, seq_f = buffer.pop(idx)
batch_cont.append(seq_c)
batch_disc.append(seq_d)
if return_file_id:
batch_file.append(seq_f)
else:
batch_cont.append(seq_cont)
batch_disc.append(seq_disc)
if return_file_id:
batch_file.append(file_id)
seq_cont = []
seq_disc = []
@@ -173,13 +312,11 @@ def windowed_batches(
seq_cont.append(cont_row)
seq_disc.append(disc_row)
if len(seq_cont) == seq_len:
- flush_seq()
+ flush_seq(file_id)
- if return_file_id:
- batch_file.append(file_id)
if len(batch_cont) == batch_size:
x_cont = torch.tensor(batch_cont, dtype=torch.float32)
x_disc = torch.tensor(batch_disc, dtype=torch.long)
- x_cont = normalize_cont(x_cont, cont_cols, mean, std)
+ x_cont = normalize_cont(x_cont, cont_cols, mean, std, transforms=transforms)
if return_file_id:
x_file = torch.tensor(batch_file, dtype=torch.long)
yield x_cont, x_disc, x_file
@@ -195,4 +332,29 @@ def windowed_batches(
seq_cont = []
seq_disc = []
# Flush any remaining buffered sequences
if shuffle_buffer and buffer:
random.shuffle(buffer)
for seq_c, seq_d, seq_f in buffer:
batch_cont.append(seq_c)
batch_disc.append(seq_d)
if return_file_id:
batch_file.append(seq_f)
if len(batch_cont) == batch_size:
import torch
x_cont = torch.tensor(batch_cont, dtype=torch.float32)
x_disc = torch.tensor(batch_disc, dtype=torch.long)
x_cont = normalize_cont(x_cont, cont_cols, mean, std, transforms=transforms)
if return_file_id:
x_file = torch.tensor(batch_file, dtype=torch.long)
yield x_cont, x_disc, x_file
else:
yield x_cont, x_disc
batch_cont = []
batch_disc = []
batch_file = []
batches_yielded += 1
if max_batches is not None and batches_yielded >= max_batches:
return
# Drop last partial batch for simplicity
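The `shuffle_buffer` logic above follows the usual bounded shuffle-buffer pattern. A self-contained sketch of the idea (illustrative only, not the repo's `windowed_batches`):

```python
import random

def shuffled(iterable, buffer_size=256, seed=0):
    """Hold up to buffer_size items; once full, emit a random one per new item."""
    rng = random.Random(seed)
    buffer = []
    for item in iterable:
        buffer.append(item)
        if len(buffer) >= buffer_size:
            yield buffer.pop(rng.randrange(len(buffer)))
    rng.shuffle(buffer)   # drain leftovers at the end, like the final flush above
    yield from buffer

print(list(shuffled(range(10), buffer_size=4)))   # locally shuffled 0..9
```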

View File

@@ -5,8 +5,9 @@ import argparse
import csv
import gzip
import json
import math
from pathlib import Path
- from typing import Dict, Tuple
+ from typing import Dict, Tuple, List, Optional
def load_json(path: str) -> Dict:
@@ -28,6 +29,8 @@ def parse_args():
parser.add_argument("--stats", default=str(base_dir / "results" / "cont_stats.json")) parser.add_argument("--stats", default=str(base_dir / "results" / "cont_stats.json"))
parser.add_argument("--vocab", default=str(base_dir / "results" / "disc_vocab.json")) parser.add_argument("--vocab", default=str(base_dir / "results" / "disc_vocab.json"))
parser.add_argument("--out", default=str(base_dir / "results" / "eval.json")) parser.add_argument("--out", default=str(base_dir / "results" / "eval.json"))
parser.add_argument("--reference", default="", help="Optional reference CSV (train) for richer metrics")
parser.add_argument("--max-rows", type=int, default=20000, help="Max rows to load for reference metrics")
return parser.parse_args() return parser.parse_args()
@@ -55,6 +58,62 @@ def finalize_stats(stats):
return out
def js_divergence(p, q, eps: float = 1e-12) -> float:
p = [max(x, eps) for x in p]
q = [max(x, eps) for x in q]
m = [(pi + qi) / 2.0 for pi, qi in zip(p, q)]
def kl(a, b):
return sum(ai * math.log(ai / bi, 2) for ai, bi in zip(a, b))
return 0.5 * kl(p, m) + 0.5 * kl(q, m)
def ks_statistic(x: List[float], y: List[float]) -> float:
if not x or not y:
return 0.0
x_sorted = sorted(x)
y_sorted = sorted(y)
n = len(x_sorted)
m = len(y_sorted)
i = j = 0
cdf_x = cdf_y = 0.0
d = 0.0
while i < n and j < m:
if x_sorted[i] <= y_sorted[j]:
i += 1
cdf_x = i / n
else:
j += 1
cdf_y = j / m
d = max(d, abs(cdf_x - cdf_y))
return d
def lag1_corr(values: List[float]) -> float:
if len(values) < 3:
return 0.0
x = values[:-1]
y = values[1:]
mean_x = sum(x) / len(x)
mean_y = sum(y) / len(y)
num = sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y))
den_x = sum((xi - mean_x) ** 2 for xi in x)
den_y = sum((yi - mean_y) ** 2 for yi in y)
if den_x <= 0 or den_y <= 0:
return 0.0
return num / math.sqrt(den_x * den_y)
def resolve_reference_path(path: str) -> Optional[str]:
if not path:
return None
if any(ch in path for ch in ["*", "?", "["]):
base = Path(path).parent
pat = Path(path).name
matches = sorted(base.glob(pat))
return str(matches[0]) if matches else None
return str(path)
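A few quick sanity checks for the helpers above (editor's examples with made-up inputs, to be read alongside the definitions):

```python
assert js_divergence([0.5, 0.5], [0.5, 0.5]) < 1e-9             # identical distributions -> 0
assert abs(js_divergence([1.0, 0.0], [0.0, 1.0]) - 1.0) < 1e-6  # disjoint support -> max (base 2)
assert ks_statistic([0, 1, 2, 3], [10, 11, 12, 13]) == 1.0      # non-overlapping samples
assert abs(lag1_corr([1, 2, 3, 4, 5]) - 1.0) < 1e-9             # a ramp is perfectly autocorrelated
```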
def main():
args = parse_args()
base_dir = Path(__file__).resolve().parent
@@ -63,13 +122,18 @@ def main():
args.stats = str((base_dir / args.stats).resolve()) if not Path(args.stats).is_absolute() else args.stats
args.vocab = str((base_dir / args.vocab).resolve()) if not Path(args.vocab).is_absolute() else args.vocab
args.out = str((base_dir / args.out).resolve()) if not Path(args.out).is_absolute() else args.out
if args.reference and not Path(args.reference).is_absolute():
args.reference = str((base_dir / args.reference).resolve())
ref_path = resolve_reference_path(args.reference)
split = load_json(args.split)
time_col = split.get("time_column", "time")
cont_cols = [c for c in split["continuous"] if c != time_col]
disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
- stats_ref = load_json(args.stats)["mean"]
- std_ref = load_json(args.stats)["std"]
+ stats_json = load_json(args.stats)
+ stats_ref = stats_json.get("raw_mean", stats_json.get("mean"))
std_ref = stats_json.get("raw_std", stats_json.get("std"))
transforms = stats_json.get("transform", {})
vocab = load_json(args.vocab)["vocab"]
vocab_sets = {c: set(vocab[c].keys()) for c in disc_cols}
@@ -89,6 +153,8 @@ def main():
except Exception:
v = 0.0
update_stats(cont_stats, c, v)
if ref_path:
pass
for c in disc_cols:
if row[c] not in vocab_sets[c]:
disc_invalid[c] += 1
@@ -112,6 +178,81 @@ def main():
"discrete_invalid_counts": disc_invalid, "discrete_invalid_counts": disc_invalid,
} }
# Optional richer metrics using reference data
if ref_path:
ref_cont = {c: [] for c in cont_cols}
ref_disc = {c: {} for c in disc_cols}
gen_cont = {c: [] for c in cont_cols}
gen_disc = {c: {} for c in disc_cols}
with open_csv(args.generated) as f:
reader = csv.DictReader(f)
for row in reader:
if time_col in row:
row.pop(time_col, None)
for c in cont_cols:
try:
gen_cont[c].append(float(row[c]))
except Exception:
gen_cont[c].append(0.0)
for c in disc_cols:
tok = row[c]
gen_disc[c][tok] = gen_disc[c].get(tok, 0) + 1
with open_csv(ref_path) as f:
reader = csv.DictReader(f)
for i, row in enumerate(reader):
if time_col in row:
row.pop(time_col, None)
for c in cont_cols:
try:
ref_cont[c].append(float(row[c]))
except Exception:
ref_cont[c].append(0.0)
for c in disc_cols:
tok = row[c]
ref_disc[c][tok] = ref_disc[c].get(tok, 0) + 1
if args.max_rows and i + 1 >= args.max_rows:
break
# Continuous metrics: KS + quantiles + lag1 correlation
cont_ks = {}
cont_quant = {}
cont_lag1 = {}
for c in cont_cols:
cont_ks[c] = ks_statistic(gen_cont[c], ref_cont[c])
ref_sorted = sorted(ref_cont[c])
gen_sorted = sorted(gen_cont[c])
qs = [0.05, 0.25, 0.5, 0.75, 0.95]
def qval(arr, q):
if not arr:
return 0.0
idx = int(q * (len(arr) - 1))
return arr[idx]
cont_quant[c] = {
"q05_diff": abs(qval(gen_sorted, 0.05) - qval(ref_sorted, 0.05)),
"q25_diff": abs(qval(gen_sorted, 0.25) - qval(ref_sorted, 0.25)),
"q50_diff": abs(qval(gen_sorted, 0.5) - qval(ref_sorted, 0.5)),
"q75_diff": abs(qval(gen_sorted, 0.75) - qval(ref_sorted, 0.75)),
"q95_diff": abs(qval(gen_sorted, 0.95) - qval(ref_sorted, 0.95)),
}
cont_lag1[c] = abs(lag1_corr(gen_cont[c]) - lag1_corr(ref_cont[c]))
# Discrete metrics: JSD over vocab
disc_jsd = {}
for c in disc_cols:
vocab_vals = list(vocab_sets[c])
gen_total = sum(gen_disc[c].values()) or 1
ref_total = sum(ref_disc[c].values()) or 1
p = [gen_disc[c].get(v, 0) / gen_total for v in vocab_vals]
q = [ref_disc[c].get(v, 0) / ref_total for v in vocab_vals]
disc_jsd[c] = js_divergence(p, q)
report["continuous_ks"] = cont_ks
report["continuous_quantile_diff"] = cont_quant
report["continuous_lag1_diff"] = cont_lag1
report["discrete_jsd"] = disc_jsd
with open(args.out, "w", encoding="utf-8") as f:
json.dump(report, f, indent=2)

View File

@@ -111,6 +111,7 @@ def main():
vmax = stats.get("max", {}) vmax = stats.get("max", {})
int_like = stats.get("int_like", {}) int_like = stats.get("int_like", {})
max_decimals = stats.get("max_decimals", {}) max_decimals = stats.get("max_decimals", {})
transforms = stats.get("transform", {})
vocab_json = json.load(open(args.vocab_path, "r", encoding="utf-8")) vocab_json = json.load(open(args.vocab_path, "r", encoding="utf-8"))
vocab = vocab_json["vocab"] vocab = vocab_json["vocab"]
@@ -141,6 +142,13 @@ def main():
model = HybridDiffusionModel(
cont_dim=len(cont_cols),
disc_vocab_sizes=vocab_sizes,
time_dim=int(cfg.get("model_time_dim", 64)),
hidden_dim=int(cfg.get("model_hidden_dim", 256)),
num_layers=int(cfg.get("model_num_layers", 1)),
dropout=float(cfg.get("model_dropout", 0.0)),
ff_mult=int(cfg.get("model_ff_mult", 2)),
pos_dim=int(cfg.get("model_pos_dim", 64)),
use_pos_embed=bool(cfg.get("model_use_pos_embed", True)),
cond_vocab_size=cond_vocab_size if use_condition else 0,
cond_dim=int(cfg.get("cond_dim", 32)),
use_tanh_eps=bool(cfg.get("use_tanh_eps", False)),
@@ -220,6 +228,9 @@ def main():
mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
x_cont = x_cont * std_vec + mean_vec
for i, c in enumerate(cont_cols):
if transforms.get(c) == "log1p":
x_cont[:, :, i] = torch.expm1(x_cont[:, :, i])
# clamp to observed min/max per feature
if vmin and vmax:
for i, c in enumerate(cont_cols):
@@ -246,8 +257,8 @@ def main():
row["__cond_file_id"] = str(int(cond[b].item())) if cond is not None else "-1" row["__cond_file_id"] = str(int(cond[b].item())) if cond is not None else "-1"
if args.include_time and time_col in header: if args.include_time and time_col in header:
row[time_col] = str(row_index) row[time_col] = str(row_index)
for i, c in enumerate(cont_cols): for i, c in enumerate(cont_cols):
val = float(x_cont[b, t, i]) val = float(x_cont[b, t, i])
if int_like.get(c, False): if int_like.get(c, False):
row[c] = str(int(round(val))) row[c] = str(int(round(val)))
else: else:

View File

@@ -35,11 +35,14 @@ def q_sample_discrete(
t: torch.Tensor,
mask_tokens: torch.Tensor,
max_t: int,
mask_scale: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Randomly mask discrete tokens with a cosine schedule over t."""
bsz = x0.size(0)
# cosine schedule: p(0)=0, p(max_t)=1
p = 0.5 * (1.0 - torch.cos(math.pi * t.float() / float(max_t)))
if mask_scale != 1.0:
p = torch.clamp(p * mask_scale, 0.0, 1.0)
p = p.view(bsz, 1, 1)
mask = torch.rand_like(x0.float()) < p
x_masked = x0.clone()
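For intuition, here is a standalone sketch of the effective masking probability (editor's illustration; the timestep count 600 and mask_scale 0.9 come from the updated config):

```python
import math

def mask_prob(t, max_t, mask_scale=1.0):
    p = 0.5 * (1.0 - math.cos(math.pi * t / max_t))   # cosine schedule, as above
    return min(max(p * mask_scale, 0.0), 1.0)

for t in (0, 300, 600):
    print(t, round(mask_prob(t, 600, mask_scale=0.9), 3))
# 0 -> 0.0, 300 -> 0.45, 600 -> 0.9: with mask_scale < 1, about 10% of tokens
# stay unmasked in expectation even at the final timestep.
```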
@@ -70,6 +73,11 @@ class HybridDiffusionModel(nn.Module):
disc_vocab_sizes: List[int],
time_dim: int = 64,
hidden_dim: int = 256,
num_layers: int = 1,
dropout: float = 0.0,
ff_mult: int = 2,
pos_dim: int = 64,
use_pos_embed: bool = True,
cond_vocab_size: int = 0,
cond_dim: int = 32,
use_tanh_eps: bool = False,
@@ -82,6 +90,8 @@ class HybridDiffusionModel(nn.Module):
self.time_embed = SinusoidalTimeEmbedding(time_dim)
self.use_tanh_eps = use_tanh_eps
self.eps_scale = eps_scale
self.pos_dim = pos_dim
self.use_pos_embed = use_pos_embed
self.cond_vocab_size = cond_vocab_size
self.cond_dim = cond_dim
@@ -96,9 +106,22 @@ class HybridDiffusionModel(nn.Module):
disc_embed_dim = sum(e.embedding_dim for e in self.disc_embeds)
self.cont_proj = nn.Linear(cont_dim, cont_dim)
- in_dim = cont_dim + disc_embed_dim + time_dim + (cond_dim if self.cond_embed is not None else 0)
+ pos_dim = pos_dim if use_pos_embed else 0
+ in_dim = cont_dim + disc_embed_dim + time_dim + pos_dim + (cond_dim if self.cond_embed is not None else 0)
self.in_proj = nn.Linear(in_dim, hidden_dim)
- self.backbone = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
+ self.backbone = nn.GRU(
hidden_dim,
hidden_dim,
num_layers=num_layers,
dropout=dropout if num_layers > 1 else 0.0,
batch_first=True,
)
self.post_norm = nn.LayerNorm(hidden_dim)
self.post_ff = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim * ff_mult),
nn.GELU(),
nn.Linear(hidden_dim * ff_mult, hidden_dim),
)
self.cont_head = nn.Linear(hidden_dim, cont_dim)
self.disc_heads = nn.ModuleList([
@@ -110,6 +133,9 @@ class HybridDiffusionModel(nn.Module):
"""x_cont: (B,T,Cc), x_disc: (B,T,Cd) with integer tokens.""" """x_cont: (B,T,Cc), x_disc: (B,T,Cd) with integer tokens."""
time_emb = self.time_embed(t) time_emb = self.time_embed(t)
time_emb = time_emb.unsqueeze(1).expand(-1, x_cont.size(1), -1) time_emb = time_emb.unsqueeze(1).expand(-1, x_cont.size(1), -1)
pos_emb = None
if self.use_pos_embed and self.pos_dim > 0:
pos_emb = self._positional_encoding(x_cont.size(1), self.pos_dim, x_cont.device)
disc_embs = [] disc_embs = []
for i, emb in enumerate(self.disc_embeds): for i, emb in enumerate(self.disc_embeds):
@@ -124,15 +150,28 @@ class HybridDiffusionModel(nn.Module):
cont_feat = self.cont_proj(x_cont)
parts = [cont_feat, disc_feat, time_emb]
if pos_emb is not None:
parts.append(pos_emb.unsqueeze(0).expand(x_cont.size(0), -1, -1))
if cond_feat is not None:
parts.append(cond_feat)
feat = torch.cat(parts, dim=-1)
feat = self.in_proj(feat)
out, _ = self.backbone(feat)
out = self.post_norm(out)
out = out + self.post_ff(out)
eps_pred = self.cont_head(out)
if self.use_tanh_eps:
eps_pred = torch.tanh(eps_pred) * self.eps_scale
logits = [head(out) for head in self.disc_heads]
return eps_pred, logits
@staticmethod
def _positional_encoding(seq_len: int, dim: int, device: torch.device) -> torch.Tensor:
pos = torch.arange(seq_len, device=device).float().unsqueeze(1)
div = torch.exp(torch.arange(0, dim, 2, device=device).float() * (-math.log(10000.0) / dim))
pe = torch.zeros(seq_len, dim, device=device)
pe[:, 0::2] = torch.sin(pos * div)
pe[:, 1::2] = torch.cos(pos * div)
return pe
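A small check of the sinusoidal table added above (editor's example; assumes PyTorch is installed and calls the staticmethod directly):

```python
import torch

pe = HybridDiffusionModel._positional_encoding(seq_len=128, dim=64, device=torch.device("cpu"))
assert pe.shape == (128, 64)
assert torch.allclose(pe[0, 0::2], torch.zeros(32))   # sin(0) = 0 on even channels
assert torch.allclose(pe[0, 1::2], torch.ones(32))    # cos(0) = 1 on odd channels
```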

View File

@@ -5,7 +5,7 @@ import json
from pathlib import Path
from typing import Optional
- from data_utils import compute_cont_stats, build_disc_stats, load_split
+ from data_utils import compute_cont_stats, build_disc_stats, load_split, choose_cont_transforms
from platform_utils import safe_path, ensure_dir
BASE_DIR = Path(__file__).resolve().parent
@@ -27,20 +27,25 @@ def main(max_rows: Optional[int] = None):
raise SystemExit("no train files found under %s" % str(DATA_GLOB)) raise SystemExit("no train files found under %s" % str(DATA_GLOB))
data_paths = [safe_path(p) for p in data_paths] data_paths = [safe_path(p) for p in data_paths]
mean, std, vmin, vmax, int_like, max_decimals = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows) transforms, _ = choose_cont_transforms(data_paths, cont_cols, max_rows=max_rows)
cont_stats = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows, transforms=transforms)
vocab, top_token = build_disc_stats(data_paths, disc_cols, max_rows=max_rows) vocab, top_token = build_disc_stats(data_paths, disc_cols, max_rows=max_rows)
ensure_dir(OUT_STATS.parent) ensure_dir(OUT_STATS.parent)
with open(safe_path(OUT_STATS), "w", encoding="utf-8") as f: with open(safe_path(OUT_STATS), "w", encoding="utf-8") as f:
json.dump( json.dump(
{ {
"mean": mean, "mean": cont_stats["mean"],
"std": std, "std": cont_stats["std"],
"min": vmin, "raw_mean": cont_stats["raw_mean"],
"max": vmax, "raw_std": cont_stats["raw_std"],
"int_like": int_like, "min": cont_stats["min"],
"max_decimals": max_decimals, "max": cont_stats["max"],
"max_rows": max_rows, "int_like": cont_stats["int_like"],
"max_decimals": cont_stats["max_decimals"],
"transform": cont_stats["transform"],
"skew": cont_stats["skew"],
"max_rows": cont_stats["max_rows"],
}, },
f, f,
indent=2, indent=2,

View File

@@ -45,7 +45,7 @@
"P3_PIT01": 668.9722350000003, "P3_PIT01": 668.9722350000003,
"P4_HT_FD": -0.00010012580000000082, "P4_HT_FD": -0.00010012580000000082,
"P4_HT_LD": 35.41945000099953, "P4_HT_LD": 35.41945000099953,
"P4_HT_PO": 35.4085699912002, "P4_HT_PO": 2.6391372939040414,
"P4_LD": 365.3833745803986, "P4_LD": 365.3833745803986,
"P4_ST_FD": -6.5205999999999635e-06, "P4_ST_FD": -6.5205999999999635e-06,
"P4_ST_GOV": 17801.81294499996, "P4_ST_GOV": 17801.81294499996,
@@ -100,7 +100,7 @@
"P3_PIT01": 1168.1071264424027, "P3_PIT01": 1168.1071264424027,
"P4_HT_FD": 0.002032582380617592, "P4_HT_FD": 0.002032582380617592,
"P4_HT_LD": 33.212361169253235, "P4_HT_LD": 33.212361169253235,
"P4_HT_PO": 31.187825914515162, "P4_HT_PO": 1.7636196192459512,
"P4_LD": 59.736616589045646, "P4_LD": 59.736616589045646,
"P4_ST_FD": 0.0016428787127432496, "P4_ST_FD": 0.0016428787127432496,
"P4_ST_GOV": 1740.5997458128215, "P4_ST_GOV": 1740.5997458128215,
@@ -109,6 +109,116 @@
"P4_ST_PT01": 22.459962818146252, "P4_ST_PT01": 22.459962818146252,
"P4_ST_TT01": 24.745939350221477 "P4_ST_TT01": 24.745939350221477
}, },
"raw_mean": {
"P1_B2004": 0.08649086820000026,
"P1_B2016": 1.376161456000001,
"P1_B3004": 396.1861596906018,
"P1_B3005": 1037.372384413793,
"P1_B4002": 32.564872940799994,
"P1_B4005": 65.98190757240047,
"P1_B400B": 1925.0391570245934,
"P1_B4022": 36.28908066800001,
"P1_FCV02Z": 21.744261118400036,
"P1_FCV03D": 57.36123274140044,
"P1_FCV03Z": 58.05084519640002,
"P1_FT01": 184.18615112319728,
"P1_FT01Z": 851.8781750705965,
"P1_FT02": 1255.8572173544069,
"P1_FT02Z": 1925.0210755194114,
"P1_FT03": 269.37285885780574,
"P1_FT03Z": 1037.366172230601,
"P1_LCV01D": 11.228849048599963,
"P1_LCV01Z": 10.991610181600016,
"P1_LIT01": 396.8845311109994,
"P1_PCV01D": 53.80101618419986,
"P1_PCV01Z": 54.646640287199595,
"P1_PCV02Z": 12.017773542800072,
"P1_PIT01": 1.3692859488000075,
"P1_PIT02": 0.44459071260000227,
"P1_TIT01": 35.64255813999988,
"P1_TIT02": 36.44807823060023,
"P2_24Vdc": 28.0280019013999,
"P2_CO_rpm": 54105.64434999997,
"P2_HILout": 712.0588667425922,
"P2_MSD": 763.19324,
"P2_SIT01": 778.7769850000013,
"P2_SIT02": 778.7778935471981,
"P2_VT01": 11.914949448200044,
"P2_VXT02": -3.5267871940000175,
"P2_VXT03": -1.5520904921999914,
"P2_VYT02": 3.796112737600002,
"P2_VYT03": 6.121691697000018,
"P3_FIT01": 1168.2528800000014,
"P3_LCP01D": 4675.465239999989,
"P3_LCV01D": 7445.208720000017,
"P3_LIT01": 13728.982314999852,
"P3_PIT01": 668.9722350000003,
"P4_HT_FD": -0.00010012580000000082,
"P4_HT_LD": 35.41945000099953,
"P4_HT_PO": 35.4085699912002,
"P4_LD": 365.3833745803986,
"P4_ST_FD": -6.5205999999999635e-06,
"P4_ST_GOV": 17801.81294499996,
"P4_ST_LD": 329.83259218199964,
"P4_ST_PO": 330.1079461497967,
"P4_ST_PT01": 10047.679605000127,
"P4_ST_TT01": 27606.860070000155
},
"raw_std": {
"P1_B2004": 0.024492489898690458,
"P1_B2016": 0.12949272564759745,
"P1_B3004": 10.16264800653289,
"P1_B3005": 70.85697659109,
"P1_B4002": 0.7578213113008355,
"P1_B4005": 41.80065314991797,
"P1_B400B": 1176.6445547448632,
"P1_B4022": 0.8221115066487089,
"P1_FCV02Z": 39.11843197764177,
"P1_FCV03D": 7.889507447726625,
"P1_FCV03Z": 8.046068905945717,
"P1_FT01": 30.80117031882856,
"P1_FT01Z": 91.2786865433318,
"P1_FT02": 879.7163277334494,
"P1_FT02Z": 1176.6699531305117,
"P1_FT03": 38.18015841964941,
"P1_FT03Z": 70.73100774436428,
"P1_LCV01D": 3.3355655415557597,
"P1_LCV01Z": 3.386332233773545,
"P1_LIT01": 10.578714760104122,
"P1_PCV01D": 19.61567943613885,
"P1_PCV01Z": 19.778754467302086,
"P1_PCV02Z": 0.0048047978931599995,
"P1_PIT01": 0.0776614954053113,
"P1_PIT02": 0.44823231815652304,
"P1_TIT01": 0.5986678527528814,
"P1_TIT02": 1.1892341204521049,
"P2_24Vdc": 0.003208842504097816,
"P2_CO_rpm": 20.575477821507334,
"P2_HILout": 8.178853379908608,
"P2_MSD": 1.0,
"P2_SIT01": 3.894535775667256,
"P2_SIT02": 3.8824770788579395,
"P2_VT01": 0.06812990916670247,
"P2_VXT02": 0.43104157117568803,
"P2_VXT03": 0.26894251958139775,
"P2_VYT02": 0.4610907883207586,
"P2_VYT03": 0.30596429385075474,
"P3_FIT01": 1787.2987693141868,
"P3_LCP01D": 5145.4094261812725,
"P3_LCV01D": 6785.602781765096,
"P3_LIT01": 4060.915441872745,
"P3_PIT01": 1168.1071264424027,
"P4_HT_FD": 0.002032582380617592,
"P4_HT_LD": 33.21236116925323,
"P4_HT_PO": 31.18782591451516,
"P4_LD": 59.736616589045646,
"P4_ST_FD": 0.0016428787127432496,
"P4_ST_GOV": 1740.5997458128213,
"P4_ST_LD": 35.86633288900077,
"P4_ST_PO": 32.375012735256696,
"P4_ST_PT01": 22.45996281814625,
"P4_ST_TT01": 24.745939350221487
},
"min": { "min": {
"P1_B2004": 0.03051, "P1_B2004": 0.03051,
"P1_B2016": 0.94729, "P1_B2016": 0.94729,
@@ -329,5 +439,115 @@
"P4_ST_PT01": 2, "P4_ST_PT01": 2,
"P4_ST_TT01": 2 "P4_ST_TT01": 2
}, },
"transform": {
"P1_B2004": "none",
"P1_B2016": "none",
"P1_B3004": "none",
"P1_B3005": "none",
"P1_B4002": "none",
"P1_B4005": "none",
"P1_B400B": "none",
"P1_B4022": "none",
"P1_FCV02Z": "none",
"P1_FCV03D": "none",
"P1_FCV03Z": "none",
"P1_FT01": "none",
"P1_FT01Z": "none",
"P1_FT02": "none",
"P1_FT02Z": "none",
"P1_FT03": "none",
"P1_FT03Z": "none",
"P1_LCV01D": "none",
"P1_LCV01Z": "none",
"P1_LIT01": "none",
"P1_PCV01D": "none",
"P1_PCV01Z": "none",
"P1_PCV02Z": "none",
"P1_PIT01": "none",
"P1_PIT02": "none",
"P1_TIT01": "none",
"P1_TIT02": "none",
"P2_24Vdc": "none",
"P2_CO_rpm": "none",
"P2_HILout": "none",
"P2_MSD": "none",
"P2_SIT01": "none",
"P2_SIT02": "none",
"P2_VT01": "none",
"P2_VXT02": "none",
"P2_VXT03": "none",
"P2_VYT02": "none",
"P2_VYT03": "none",
"P3_FIT01": "none",
"P3_LCP01D": "none",
"P3_LCV01D": "none",
"P3_LIT01": "none",
"P3_PIT01": "none",
"P4_HT_FD": "none",
"P4_HT_LD": "none",
"P4_HT_PO": "log1p",
"P4_LD": "none",
"P4_ST_FD": "none",
"P4_ST_GOV": "none",
"P4_ST_LD": "none",
"P4_ST_PO": "none",
"P4_ST_PT01": "none",
"P4_ST_TT01": "none"
},
"skew": {
"P1_B2004": -2.876938578031295e-05,
"P1_B2016": 2.014565216651284e-06,
"P1_B3004": 6.625985939357487e-06,
"P1_B3005": -9.917489652810193e-06,
"P1_B4002": 1.4641465884161855e-05,
"P1_B4005": -1.2370279269006856e-05,
"P1_B400B": -1.4116897198097317e-05,
"P1_B4022": 1.1162291352215598e-05,
"P1_FCV02Z": 2.532521501167817e-05,
"P1_FCV03D": 4.2517931711793e-06,
"P1_FCV03Z": 4.301856332440012e-06,
"P1_FT01": -1.3345735264961829e-05,
"P1_FT01Z": -4.2554413198354234e-05,
"P1_FT02": -1.0289230789249066e-05,
"P1_FT02Z": -1.4116856909216661e-05,
"P1_FT03": -4.341090521713463e-06,
"P1_FT03Z": -9.964308983887345e-06,
"P1_LCV01D": 2.541312481372867e-06,
"P1_LCV01Z": 2.5806433622267527e-06,
"P1_LIT01": 7.716120912717401e-06,
"P1_PCV01D": 2.113459306618771e-05,
"P1_PCV01Z": 2.0632832525407433e-05,
"P1_PCV02Z": 4.2639616636720384e-08,
"P1_PIT01": 2.079887220863843e-05,
"P1_PIT02": 5.003507344873546e-05,
"P1_TIT01": 9.553657000925262e-06,
"P1_TIT02": 2.1170357380515215e-05,
"P2_24Vdc": 2.735770838906968e-07,
"P2_CO_rpm": -8.124011608472296e-06,
"P2_HILout": -4.086282393330704e-06,
"P2_MSD": 0.0,
"P2_SIT01": -7.418240348817199e-06,
"P2_SIT02": -7.457826456660247e-06,
"P2_VT01": 1.247484205979928e-07,
"P2_VXT02": 6.53499778855353e-07,
"P2_VXT03": 5.32656056809399e-06,
"P2_VYT02": 9.483158480759724e-07,
"P2_VYT03": 2.128755351566922e-06,
"P3_FIT01": 2.2828575320599336e-05,
"P3_LCP01D": 1.3040993552131866e-05,
"P3_LCV01D": 3.781324885318626e-07,
"P3_LIT01": -7.824733758742217e-06,
"P3_PIT01": 3.210613447428708e-05,
"P4_HT_FD": 9.197840236384403e-05,
"P4_HT_LD": -2.4568845167931336e-08,
"P4_HT_PO": 3.997415489949367e-07,
"P4_LD": -6.253448074273654e-07,
"P4_ST_FD": 2.3472181460829935e-07,
"P4_ST_GOV": 2.494268873407866e-06,
"P4_ST_LD": 1.6692758818969547e-06,
"P4_ST_PO": 2.45129838870492e-06,
"P4_ST_PT01": 1.7637202837434092e-05,
"P4_ST_TT01": -1.9485876142550594e-05
},
"max_rows": 50000 "max_rows": 50000
} }

File diff suppressed because it is too large

View File

@@ -48,6 +48,8 @@ def main():
seq_len = cfg.get("sample_seq_len", cfg.get("seq_len", 64)) seq_len = cfg.get("sample_seq_len", cfg.get("seq_len", 64))
batch_size = cfg.get("sample_batch_size", cfg.get("batch_size", 2)) batch_size = cfg.get("sample_batch_size", cfg.get("batch_size", 2))
clip_k = cfg.get("clip_k", 5.0) clip_k = cfg.get("clip_k", 5.0)
data_glob = cfg.get("data_glob", "")
data_path = cfg.get("data_path", "")
run([sys.executable, str(base_dir / "prepare_data.py")]) run([sys.executable, str(base_dir / "prepare_data.py")])
run([sys.executable, str(base_dir / "train.py"), "--config", args.config, "--device", args.device]) run([sys.executable, str(base_dir / "train.py"), "--config", args.config, "--device", args.device])
run( run(
@@ -70,7 +72,11 @@ def main():
"--use-ema", "--use-ema",
] ]
) )
run([sys.executable, str(base_dir / "evaluate_generated.py")]) ref = data_glob if data_glob else data_path
if ref:
run([sys.executable, str(base_dir / "evaluate_generated.py"), "--reference", str(ref)])
else:
run([sys.executable, str(base_dir / "evaluate_generated.py")])
run([sys.executable, str(base_dir / "plot_loss.py")]) run([sys.executable, str(base_dir / "plot_loss.py")])

View File

@@ -47,6 +47,13 @@ def main():
cond_dim = int(cfg.get("cond_dim", 32))
use_tanh_eps = bool(cfg.get("use_tanh_eps", False))
eps_scale = float(cfg.get("eps_scale", 1.0))
model_time_dim = int(cfg.get("model_time_dim", 64))
model_hidden_dim = int(cfg.get("model_hidden_dim", 256))
model_num_layers = int(cfg.get("model_num_layers", 1))
model_dropout = float(cfg.get("model_dropout", 0.0))
model_ff_mult = int(cfg.get("model_ff_mult", 2))
model_pos_dim = int(cfg.get("model_pos_dim", 64))
model_use_pos = bool(cfg.get("model_use_pos_embed", True))
split = load_split(str(SPLIT_PATH))
time_col = split.get("time_column", "time")
@@ -67,6 +74,13 @@ def main():
model = HybridDiffusionModel(
cont_dim=len(cont_cols),
disc_vocab_sizes=vocab_sizes,
time_dim=model_time_dim,
hidden_dim=model_hidden_dim,
num_layers=model_num_layers,
dropout=model_dropout,
ff_mult=model_ff_mult,
pos_dim=model_pos_dim,
use_pos_embed=model_use_pos,
cond_vocab_size=cond_vocab_size,
cond_dim=cond_dim,
use_tanh_eps=use_tanh_eps,

View File

@@ -49,8 +49,17 @@ DEFAULTS = {
"use_condition": True, "use_condition": True,
"condition_type": "file_id", "condition_type": "file_id",
"cond_dim": 32, "cond_dim": 32,
"use_tanh_eps": True, "use_tanh_eps": False,
"eps_scale": 1.0, "eps_scale": 1.0,
"model_time_dim": 128,
"model_hidden_dim": 512,
"model_num_layers": 2,
"model_dropout": 0.1,
"model_ff_mult": 2,
"model_pos_dim": 64,
"model_use_pos_embed": True,
"disc_mask_scale": 0.9,
"shuffle_buffer": 256,
}
@@ -144,6 +153,7 @@ def main():
stats = load_json(config["stats_path"])
mean = stats["mean"]
std = stats["std"]
transforms = stats.get("transform", {})
vocab = load_json(config["vocab_path"])["vocab"]
vocab_sizes = [len(vocab[c]) for c in disc_cols]
@@ -164,6 +174,13 @@ def main():
model = HybridDiffusionModel(
cont_dim=len(cont_cols),
disc_vocab_sizes=vocab_sizes,
time_dim=int(config.get("model_time_dim", 64)),
hidden_dim=int(config.get("model_hidden_dim", 256)),
num_layers=int(config.get("model_num_layers", 1)),
dropout=float(config.get("model_dropout", 0.0)),
ff_mult=int(config.get("model_ff_mult", 2)),
pos_dim=int(config.get("model_pos_dim", 64)),
use_pos_embed=bool(config.get("model_use_pos_embed", True)),
cond_vocab_size=cond_vocab_size,
cond_dim=int(config.get("cond_dim", 32)),
use_tanh_eps=bool(config.get("use_tanh_eps", False)),
@@ -198,6 +215,8 @@ def main():
seq_len=int(config["seq_len"]), seq_len=int(config["seq_len"]),
max_batches=int(config["max_batches"]), max_batches=int(config["max_batches"]),
return_file_id=use_condition, return_file_id=use_condition,
transforms=transforms,
shuffle_buffer=int(config.get("shuffle_buffer", 0)),
) )
): ):
if use_condition: if use_condition:
@@ -215,7 +234,13 @@ def main():
x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
mask_tokens = torch.tensor(vocab_sizes, device=device)
- x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, int(config["timesteps"]))
+ x_disc_t, mask = q_sample_discrete(
x_disc,
t,
mask_tokens,
int(config["timesteps"]),
mask_scale=float(config.get("disc_mask_scale", 1.0)),
)
eps_pred, logits = model(x_cont_t, x_disc_t, t, cond)
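The loss itself is outside this diff. As a rough, hypothetical sketch of how the two heads are usually combined in this kind of hybrid setup (function name, masking treatment, and per-column averaging are assumptions; the `lam` weight corresponds to the config's "lambda"):

```python
import torch
import torch.nn.functional as F

def hybrid_loss(eps_pred, noise, logits, x_disc, mask, lam=0.5):
    # Continuous branch: regress the injected Gaussian noise.
    cont_loss = F.mse_loss(eps_pred, noise)
    # Discrete branch: cross-entropy only on positions that were masked.
    disc_loss = eps_pred.new_zeros(())
    for i, lg in enumerate(logits):                      # one head per discrete column
        ce = F.cross_entropy(lg.reshape(-1, lg.size(-1)),
                             x_disc[:, :, i].reshape(-1),
                             reduction="none")
        m = mask[:, :, i].reshape(-1).float()
        disc_loss = disc_loss + (ce * m).sum() / m.sum().clamp(min=1.0)
    return cont_loss + lam * disc_loss / max(len(logits), 1)
```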