MZ YANG
2026-01-23 23:40:44 +08:00
12 changed files with 1404 additions and 260 deletions

View File

@@ -67,6 +67,7 @@ python example/run_pipeline.py --device auto
- Optional conditioning by file id (`train*.csv.gz`) is enabled by default for multi-file training.
- Continuous head can be bounded with `tanh` via `use_tanh_eps` in config.
- Export now clamps continuous features to the training min/max and preserves integer/decimal precision (see the sketch after this list).
- Continuous features may be log1p-transformed automatically for heavy-tailed columns (see cont_stats.json).
- `<UNK>` tokens are replaced by the most frequent token for each discrete column at export.
- The script only samples the first 5000 rows to stay fast.
- `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it.
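
For illustration, a minimal sketch of the clamping and precision-preserving export step described above. The key names (`min`, `max`, `int_like`, `max_decimals`) mirror cont_stats.json, but the helper itself and the stats path are illustrative, not the pipeline's own code.

import json

def format_value(col: str, val: float, stats: dict) -> str:
    """Clamp to the training range, then format with the column's stored precision."""
    lo, hi = stats["min"][col], stats["max"][col]
    val = min(max(val, lo), hi)                      # clamp to observed min/max
    if stats.get("int_like", {}).get(col, False):
        return str(int(round(val)))                  # integer-valued column
    decimals = stats.get("max_decimals", {}).get(col, 6)
    return f"{val:.{decimals}f}"                     # keep the observed decimal places

stats = json.load(open("results/cont_stats.json", encoding="utf-8"))   # path assumed
print(format_value("P1_B3004", 1e9, stats))          # far above the training max, so it is clamped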

View File

@@ -6,11 +6,11 @@
"vocab_path": "./results/disc_vocab.json",
"out_dir": "./results",
"device": "auto",
"timesteps": 400,
"timesteps": 600,
"batch_size": 128,
"seq_len": 128,
"epochs": 8,
"max_batches": 3000,
"epochs": 10,
"max_batches": 4000,
"lambda": 0.5,
"lr": 0.0005,
"seed": 1337,
@@ -23,8 +23,17 @@
"use_condition": true,
"condition_type": "file_id",
"cond_dim": 32,
"use_tanh_eps": true,
"use_tanh_eps": false,
"eps_scale": 1.0,
"model_time_dim": 128,
"model_hidden_dim": 512,
"model_num_layers": 2,
"model_dropout": 0.1,
"model_ff_mult": 2,
"model_pos_dim": 64,
"model_use_pos_embed": true,
"disc_mask_scale": 0.9,
"shuffle_buffer": 256,
"sample_batch_size": 8,
"sample_seq_len": 128
}
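
The new model_* keys are read with dictionary fallbacks in train.py and sample.py, so older configs that omit them keep working with smaller defaults. A minimal sketch of that pattern (the config path is assumed; the fallback values mirror the ones used in the scripts):

import json

with open("example/config.json", encoding="utf-8") as f:   # path assumed
    cfg = json.load(f)

model_kwargs = dict(
    time_dim=int(cfg.get("model_time_dim", 64)),
    hidden_dim=int(cfg.get("model_hidden_dim", 256)),
    num_layers=int(cfg.get("model_num_layers", 1)),
    dropout=float(cfg.get("model_dropout", 0.0)),
    ff_mult=int(cfg.get("model_ff_mult", 2)),
    pos_dim=int(cfg.get("model_pos_dim", 64)),
    use_pos_embed=bool(cfg.get("model_use_pos_embed", True)),
)
print(model_kwargs)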

View File

@@ -4,6 +4,8 @@
import csv
import gzip
import json
import math
import random
from typing import Dict, Iterable, List, Optional, Tuple, Union
@@ -23,62 +25,173 @@ def iter_rows(path_or_paths: Union[str, List[str]]) -> Iterable[Dict[str, str]]:
yield row
def compute_cont_stats(
def _stream_basic_stats(
path: Union[str, List[str]],
cont_cols: List[str],
max_rows: Optional[int] = None,
) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, bool], Dict[str, int]]:
"""Streaming mean/std (Welford) + min/max + int/precision metadata."""
count = 0
):
"""Streaming stats with mean/M2/M3 + min/max + int/precision metadata."""
count = {c: 0 for c in cont_cols}
mean = {c: 0.0 for c in cont_cols}
m2 = {c: 0.0 for c in cont_cols}
m3 = {c: 0.0 for c in cont_cols}
vmin = {c: float("inf") for c in cont_cols}
vmax = {c: float("-inf") for c in cont_cols}
int_like = {c: True for c in cont_cols}
max_decimals = {c: 0 for c in cont_cols}
all_pos = {c: True for c in cont_cols}
for i, row in enumerate(iter_rows(path)):
count += 1
for c in cont_cols:
raw = row[c]
if raw is None or raw == "":
continue
x = float(raw)
delta = x - mean[c]
mean[c] += delta / count
delta2 = x - mean[c]
m2[c] += delta * delta2
if x <= 0:
all_pos[c] = False
if x < vmin[c]:
vmin[c] = x
if x > vmax[c]:
vmax[c] = x
if int_like[c] and abs(x - round(x)) > 1e-9:
int_like[c] = False
# track decimal places from raw string if possible
if "e" in raw or "E" in raw:
# scientific notation, skip precision inference
continue
if "." in raw:
if "e" not in raw and "E" not in raw and "." in raw:
dec = raw.split(".", 1)[1].rstrip("0")
if len(dec) > max_decimals[c]:
max_decimals[c] = len(dec)
n = count[c] + 1
delta = x - mean[c]
delta_n = delta / n
term1 = delta * delta_n * (n - 1)
m3[c] += term1 * delta_n * (n - 2) - 3 * delta_n * m2[c]
m2[c] += term1
mean[c] += delta_n
count[c] = n
if max_rows is not None and i + 1 >= max_rows:
break
# finalize std/skew
std = {}
skew = {}
for c in cont_cols:
if count > 1:
var = m2[c] / (count - 1)
n = count[c]
if n > 1:
var = m2[c] / (n - 1)
else:
var = 0.0
std[c] = var ** 0.5 if var > 0 else 1.0
# replace infs if column had no valid values
if n > 2 and m2[c] > 0:
# sample skewness: g1 = sqrt(n) * M3 / M2 ** 1.5
skew[c] = (math.sqrt(n) * m3[c]) / (m2[c] ** 1.5)
else:
skew[c] = 0.0
for c in cont_cols:
if vmin[c] == float("inf"):
vmin[c] = 0.0
if vmax[c] == float("-inf"):
vmax[c] = 0.0
return mean, std, vmin, vmax, int_like, max_decimals
return {
"count": count,
"mean": mean,
"std": std,
"m2": m2,
"m3": m3,
"min": vmin,
"max": vmax,
"int_like": int_like,
"max_decimals": max_decimals,
"skew": skew,
"all_pos": all_pos,
}
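
The update inside the loop is the standard single-pass (Welford-style) recurrence extended to the third moment. A self-contained sketch that cross-checks it against a naive two-pass computation (independent of _stream_basic_stats, for illustration only):

import math
import random

def naive_skew(xs):
    n = len(xs)
    mu = sum(xs) / n
    m2 = sum((x - mu) ** 2 for x in xs) / n
    m3 = sum((x - mu) ** 3 for x in xs) / n
    return m3 / m2 ** 1.5

def streaming_skew(xs):
    n, mean, m2, m3 = 0, 0.0, 0.0, 0.0
    for x in xs:
        n += 1
        delta = x - mean
        delta_n = delta / n
        term1 = delta * delta_n * (n - 1)
        m3 += term1 * delta_n * (n - 2) - 3 * delta_n * m2   # uses m2 before it is updated
        m2 += term1
        mean += delta_n
    return math.sqrt(n) * m3 / m2 ** 1.5

xs = [random.expovariate(1.0) for _ in range(10_000)]          # heavy-tailed sample, skew ~ 2
print(round(naive_skew(xs), 6), round(streaming_skew(xs), 6))  # the two values should match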
def choose_cont_transforms(
path: Union[str, List[str]],
cont_cols: List[str],
max_rows: Optional[int] = None,
skew_threshold: float = 1.5,
range_ratio_threshold: float = 1e3,
):
"""Pick per-column transform (currently log1p or none) based on skew/range."""
stats = _stream_basic_stats(path, cont_cols, max_rows=max_rows)
transforms = {}
for c in cont_cols:
if not stats["all_pos"][c]:
transforms[c] = "none"
continue
skew = abs(stats["skew"][c])
vmin = stats["min"][c]
vmax = stats["max"][c]
ratio = (vmax / vmin) if vmin > 0 else 0.0
if skew >= skew_threshold or ratio >= range_ratio_threshold:
transforms[c] = "log1p"
else:
transforms[c] = "none"
return transforms, stats
def compute_cont_stats(
path: Union[str, List[str]],
cont_cols: List[str],
max_rows: Optional[int] = None,
transforms: Optional[Dict[str, str]] = None,
):
"""Compute stats on (optionally transformed) values. Returns raw + transformed stats."""
# First pass (raw) for metadata and raw mean/std
raw = _stream_basic_stats(path, cont_cols, max_rows=max_rows)
# Optional transform selection
if transforms is None:
transforms = {c: "none" for c in cont_cols}
# Second pass for transformed mean/std
count = {c: 0 for c in cont_cols}
mean = {c: 0.0 for c in cont_cols}
m2 = {c: 0.0 for c in cont_cols}
for i, row in enumerate(iter_rows(path)):
for c in cont_cols:
raw_val = row[c]
if raw_val is None or raw_val == "":
continue
x = float(raw_val)
if transforms.get(c) == "log1p":
if x < 0:
x = 0.0
x = math.log1p(x)
n = count[c] + 1
delta = x - mean[c]
mean[c] += delta / n
delta2 = x - mean[c]
m2[c] += delta * delta2
count[c] = n
if max_rows is not None and i + 1 >= max_rows:
break
std = {}
for c in cont_cols:
if count[c] > 1:
var = m2[c] / (count[c] - 1)
else:
var = 0.0
std[c] = var ** 0.5 if var > 0 else 1.0
return {
"mean": mean,
"std": std,
"raw_mean": raw["mean"],
"raw_std": raw["std"],
"min": raw["min"],
"max": raw["max"],
"int_like": raw["int_like"],
"max_decimals": raw["max_decimals"],
"transform": transforms,
"skew": raw["skew"],
"all_pos": raw["all_pos"],
"max_rows": max_rows,
}
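
With the chosen transform stored next to mean/std, training normalizes log1p-transformed values and sampling inverts the whole chain with expm1 (see normalize_cont below and the expm1 step in sample.py). A small numeric sketch of that round trip with plain floats (the mean/std values are made up):

import math

mean, std = 3.5, 1.2                      # stats computed on log1p-transformed values (assumed numbers)
x = 250.0                                 # raw feature value

z = (math.log1p(x) - mean) / std          # training side: log1p, then standardize
x_back = math.expm1(z * std + mean)       # sampling side: de-standardize, then expm1
print(round(x_back, 6))                   # 250.0 up to floating-point rounding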
def build_vocab(
@@ -130,8 +243,19 @@ def build_disc_stats(
return vocab, top_token
def normalize_cont(x, cont_cols: List[str], mean: Dict[str, float], std: Dict[str, float]):
def normalize_cont(
x,
cont_cols: List[str],
mean: Dict[str, float],
std: Dict[str, float],
transforms: Optional[Dict[str, str]] = None,
):
import torch
if transforms:
for i, c in enumerate(cont_cols):
if transforms.get(c) == "log1p":
x[:, :, i] = torch.log1p(torch.clamp(x[:, :, i], min=0))
mean_t = torch.tensor([mean[c] for c in cont_cols], dtype=x.dtype, device=x.device)
std_t = torch.tensor([std[c] for c in cont_cols], dtype=x.dtype, device=x.device)
return (x - mean_t) / std_t
@@ -148,19 +272,34 @@ def windowed_batches(
seq_len: int,
max_batches: Optional[int] = None,
return_file_id: bool = False,
transforms: Optional[Dict[str, str]] = None,
shuffle_buffer: int = 0,
):
import torch
batch_cont = []
batch_disc = []
batch_file = []
buffer = []
seq_cont = []
seq_disc = []
def flush_seq():
nonlocal seq_cont, seq_disc, batch_cont, batch_disc
def flush_seq(file_id: int):
nonlocal seq_cont, seq_disc, batch_cont, batch_disc, batch_file
if len(seq_cont) == seq_len:
batch_cont.append(seq_cont)
batch_disc.append(seq_disc)
if shuffle_buffer and shuffle_buffer > 0:
buffer.append((list(seq_cont), list(seq_disc), file_id))
if len(buffer) >= shuffle_buffer:
idx = random.randrange(len(buffer))
seq_c, seq_d, seq_f = buffer.pop(idx)
batch_cont.append(seq_c)
batch_disc.append(seq_d)
if return_file_id:
batch_file.append(seq_f)
else:
batch_cont.append(seq_cont)
batch_disc.append(seq_disc)
if return_file_id:
batch_file.append(file_id)
seq_cont = []
seq_disc = []
@@ -173,13 +312,11 @@ def windowed_batches(
seq_cont.append(cont_row)
seq_disc.append(disc_row)
if len(seq_cont) == seq_len:
flush_seq()
if return_file_id:
batch_file.append(file_id)
flush_seq(file_id)
if len(batch_cont) == batch_size:
x_cont = torch.tensor(batch_cont, dtype=torch.float32)
x_disc = torch.tensor(batch_disc, dtype=torch.long)
x_cont = normalize_cont(x_cont, cont_cols, mean, std)
x_cont = normalize_cont(x_cont, cont_cols, mean, std, transforms=transforms)
if return_file_id:
x_file = torch.tensor(batch_file, dtype=torch.long)
yield x_cont, x_disc, x_file
@@ -195,4 +332,29 @@ def windowed_batches(
seq_cont = []
seq_disc = []
# Flush any remaining buffered sequences
if shuffle_buffer and buffer:
random.shuffle(buffer)
for seq_c, seq_d, seq_f in buffer:
batch_cont.append(seq_c)
batch_disc.append(seq_d)
if return_file_id:
batch_file.append(seq_f)
if len(batch_cont) == batch_size:
import torch
x_cont = torch.tensor(batch_cont, dtype=torch.float32)
x_disc = torch.tensor(batch_disc, dtype=torch.long)
x_cont = normalize_cont(x_cont, cont_cols, mean, std, transforms=transforms)
if return_file_id:
x_file = torch.tensor(batch_file, dtype=torch.long)
yield x_cont, x_disc, x_file
else:
yield x_cont, x_disc
batch_cont = []
batch_disc = []
batch_file = []
batches_yielded += 1
if max_batches is not None and batches_yielded >= max_batches:
return
# Drop last partial batch for simplicity
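
The shuffle_buffer handling above is a streaming shuffle: full sequences accumulate in a bounded pool and are emitted by popping random indices, which decorrelates neighbouring windows without reading the whole file into memory. A standalone sketch of the same pattern over an arbitrary iterator (names are illustrative):

import random

def shuffle_stream(items, buffer_size=256, seed=0):
    """Yield items in a locally shuffled order using a bounded buffer."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        buffer.append(item)
        if len(buffer) >= buffer_size:
            yield buffer.pop(rng.randrange(len(buffer)))
    rng.shuffle(buffer)                   # drain whatever is left once the stream ends
    yield from buffer

print(list(shuffle_stream(range(10), buffer_size=4)))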

View File

@@ -5,8 +5,9 @@ import argparse
import csv
import gzip
import json
import math
from pathlib import Path
from typing import Dict, Tuple
from typing import Dict, Tuple, List, Optional
def load_json(path: str) -> Dict:
@@ -28,6 +29,8 @@ def parse_args():
parser.add_argument("--stats", default=str(base_dir / "results" / "cont_stats.json"))
parser.add_argument("--vocab", default=str(base_dir / "results" / "disc_vocab.json"))
parser.add_argument("--out", default=str(base_dir / "results" / "eval.json"))
parser.add_argument("--reference", default="", help="Optional reference CSV (train) for richer metrics")
parser.add_argument("--max-rows", type=int, default=20000, help="Max rows to load for reference metrics")
return parser.parse_args()
@@ -55,6 +58,62 @@ def finalize_stats(stats):
return out
def js_divergence(p, q, eps: float = 1e-12) -> float:
p = [max(x, eps) for x in p]
q = [max(x, eps) for x in q]
m = [(pi + qi) / 2.0 for pi, qi in zip(p, q)]
def kl(a, b):
return sum(ai * math.log(ai / bi, 2) for ai, bi in zip(a, b))
return 0.5 * kl(p, m) + 0.5 * kl(q, m)
def ks_statistic(x: List[float], y: List[float]) -> float:
if not x or not y:
return 0.0
x_sorted = sorted(x)
y_sorted = sorted(y)
n = len(x_sorted)
m = len(y_sorted)
i = j = 0
cdf_x = cdf_y = 0.0
d = 0.0
while i < n and j < m:
if x_sorted[i] <= y_sorted[j]:
i += 1
cdf_x = i / n
else:
j += 1
cdf_y = j / m
d = max(d, abs(cdf_x - cdf_y))
return d
def lag1_corr(values: List[float]) -> float:
if len(values) < 3:
return 0.0
x = values[:-1]
y = values[1:]
mean_x = sum(x) / len(x)
mean_y = sum(y) / len(y)
num = sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y))
den_x = sum((xi - mean_x) ** 2 for xi in x)
den_y = sum((yi - mean_y) ** 2 for yi in y)
if den_x <= 0 or den_y <= 0:
return 0.0
return num / math.sqrt(den_x * den_y)
def resolve_reference_path(path: str) -> Optional[str]:
if not path:
return None
if any(ch in path for ch in ["*", "?", "["]):
base = Path(path).parent
pat = Path(path).name
matches = sorted(base.glob(pat))
return str(matches[0]) if matches else None
return str(path)
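
A quick sanity check of the metric helpers on hand-computable inputs (this assumes the script imports cleanly as a module, which its structure here suggests):

from evaluate_generated import js_divergence, ks_statistic, lag1_corr

print(js_divergence([1.0, 0.0], [0.0, 1.0]))                     # ~1.0 bit for disjoint distributions
print(js_divergence([0.5, 0.5], [0.5, 0.5]))                     # 0.0 for identical distributions
print(ks_statistic([1.0, 2.0, 3.0, 4.0], [1.5, 2.5, 3.5, 4.5]))  # 0.25: the CDFs differ by one step
print(ks_statistic([0.0, 0.1, 0.2], [5.0, 6.0, 7.0]))            # 1.0: the samples do not overlap
print(lag1_corr([1.0, 2.0, 3.0, 4.0, 5.0]))                      # 1.0 for a strictly increasing ramp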
def main():
args = parse_args()
base_dir = Path(__file__).resolve().parent
@@ -63,13 +122,18 @@ def main():
args.stats = str((base_dir / args.stats).resolve()) if not Path(args.stats).is_absolute() else args.stats
args.vocab = str((base_dir / args.vocab).resolve()) if not Path(args.vocab).is_absolute() else args.vocab
args.out = str((base_dir / args.out).resolve()) if not Path(args.out).is_absolute() else args.out
if args.reference and not Path(args.reference).is_absolute():
args.reference = str((base_dir / args.reference).resolve())
ref_path = resolve_reference_path(args.reference)
split = load_json(args.split)
time_col = split.get("time_column", "time")
cont_cols = [c for c in split["continuous"] if c != time_col]
disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
stats_ref = load_json(args.stats)["mean"]
std_ref = load_json(args.stats)["std"]
stats_json = load_json(args.stats)
stats_ref = stats_json.get("raw_mean", stats_json.get("mean"))
std_ref = stats_json.get("raw_std", stats_json.get("std"))
transforms = stats_json.get("transform", {})
vocab = load_json(args.vocab)["vocab"]
vocab_sets = {c: set(vocab[c].keys()) for c in disc_cols}
@@ -89,6 +153,8 @@ def main():
except Exception:
v = 0.0
update_stats(cont_stats, c, v)
for c in disc_cols:
if row[c] not in vocab_sets[c]:
disc_invalid[c] += 1
@@ -112,6 +178,81 @@ def main():
"discrete_invalid_counts": disc_invalid,
}
# Optional richer metrics using reference data
if ref_path:
ref_cont = {c: [] for c in cont_cols}
ref_disc = {c: {} for c in disc_cols}
gen_cont = {c: [] for c in cont_cols}
gen_disc = {c: {} for c in disc_cols}
with open_csv(args.generated) as f:
reader = csv.DictReader(f)
for row in reader:
if time_col in row:
row.pop(time_col, None)
for c in cont_cols:
try:
gen_cont[c].append(float(row[c]))
except Exception:
gen_cont[c].append(0.0)
for c in disc_cols:
tok = row[c]
gen_disc[c][tok] = gen_disc[c].get(tok, 0) + 1
with open_csv(ref_path) as f:
reader = csv.DictReader(f)
for i, row in enumerate(reader):
if time_col in row:
row.pop(time_col, None)
for c in cont_cols:
try:
ref_cont[c].append(float(row[c]))
except Exception:
ref_cont[c].append(0.0)
for c in disc_cols:
tok = row[c]
ref_disc[c][tok] = ref_disc[c].get(tok, 0) + 1
if args.max_rows and i + 1 >= args.max_rows:
break
# Continuous metrics: KS + quantiles + lag1 correlation
cont_ks = {}
cont_quant = {}
cont_lag1 = {}
for c in cont_cols:
cont_ks[c] = ks_statistic(gen_cont[c], ref_cont[c])
ref_sorted = sorted(ref_cont[c])
gen_sorted = sorted(gen_cont[c])
qs = [0.05, 0.25, 0.5, 0.75, 0.95]
def qval(arr, q):
if not arr:
return 0.0
idx = int(q * (len(arr) - 1))
return arr[idx]
cont_quant[c] = {
"q05_diff": abs(qval(gen_sorted, 0.05) - qval(ref_sorted, 0.05)),
"q25_diff": abs(qval(gen_sorted, 0.25) - qval(ref_sorted, 0.25)),
"q50_diff": abs(qval(gen_sorted, 0.5) - qval(ref_sorted, 0.5)),
"q75_diff": abs(qval(gen_sorted, 0.75) - qval(ref_sorted, 0.75)),
"q95_diff": abs(qval(gen_sorted, 0.95) - qval(ref_sorted, 0.95)),
}
cont_lag1[c] = abs(lag1_corr(gen_cont[c]) - lag1_corr(ref_cont[c]))
# Discrete metrics: JSD over vocab
disc_jsd = {}
for c in disc_cols:
vocab_vals = list(vocab_sets[c])
gen_total = sum(gen_disc[c].values()) or 1
ref_total = sum(ref_disc[c].values()) or 1
p = [gen_disc[c].get(v, 0) / gen_total for v in vocab_vals]
q = [ref_disc[c].get(v, 0) / ref_total for v in vocab_vals]
disc_jsd[c] = js_divergence(p, q)
report["continuous_ks"] = cont_ks
report["continuous_quantile_diff"] = cont_quant
report["continuous_lag1_diff"] = cont_lag1
report["discrete_jsd"] = disc_jsd
with open(args.out, "w", encoding="utf-8") as f:
json.dump(report, f, indent=2)

View File

@@ -111,6 +111,7 @@ def main():
vmax = stats.get("max", {})
int_like = stats.get("int_like", {})
max_decimals = stats.get("max_decimals", {})
transforms = stats.get("transform", {})
vocab_json = json.load(open(args.vocab_path, "r", encoding="utf-8"))
vocab = vocab_json["vocab"]
@@ -141,6 +142,13 @@ def main():
model = HybridDiffusionModel(
cont_dim=len(cont_cols),
disc_vocab_sizes=vocab_sizes,
time_dim=int(cfg.get("model_time_dim", 64)),
hidden_dim=int(cfg.get("model_hidden_dim", 256)),
num_layers=int(cfg.get("model_num_layers", 1)),
dropout=float(cfg.get("model_dropout", 0.0)),
ff_mult=int(cfg.get("model_ff_mult", 2)),
pos_dim=int(cfg.get("model_pos_dim", 64)),
use_pos_embed=bool(cfg.get("model_use_pos_embed", True)),
cond_vocab_size=cond_vocab_size if use_condition else 0,
cond_dim=int(cfg.get("cond_dim", 32)),
use_tanh_eps=bool(cfg.get("use_tanh_eps", False)),
@@ -220,6 +228,9 @@ def main():
mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
x_cont = x_cont * std_vec + mean_vec
for i, c in enumerate(cont_cols):
if transforms.get(c) == "log1p":
x_cont[:, :, i] = torch.expm1(x_cont[:, :, i])
# clamp to observed min/max per feature
if vmin and vmax:
for i, c in enumerate(cont_cols):
@@ -246,8 +257,8 @@ def main():
row["__cond_file_id"] = str(int(cond[b].item())) if cond is not None else "-1"
if args.include_time and time_col in header:
row[time_col] = str(row_index)
for i, c in enumerate(cont_cols):
val = float(x_cont[b, t, i])
for i, c in enumerate(cont_cols):
val = float(x_cont[b, t, i])
if int_like.get(c, False):
row[c] = str(int(round(val)))
else:

View File

@@ -35,11 +35,14 @@ def q_sample_discrete(
t: torch.Tensor,
mask_tokens: torch.Tensor,
max_t: int,
mask_scale: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Randomly mask discrete tokens with a cosine schedule over t."""
bsz = x0.size(0)
# cosine schedule: p(0)=0, p(max_t)=1
p = 0.5 * (1.0 - torch.cos(math.pi * t.float() / float(max_t)))
if mask_scale != 1.0:
p = torch.clamp(p * mask_scale, 0.0, 1.0)
p = p.view(bsz, 1, 1)
mask = torch.rand_like(x0.float()) < p
x_masked = x0.clone()
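
The cosine schedule gives a masking probability of 0 at t=0 and 1 at t=max_t, and mask_scale simply caps that ceiling (0.9 in the updated config). A tiny sketch of the resulting probabilities, no model required:

import math

def mask_prob(t: int, max_t: int, mask_scale: float = 1.0) -> float:
    p = 0.5 * (1.0 - math.cos(math.pi * t / max_t))
    return min(max(p * mask_scale, 0.0), 1.0)

for t in (0, 150, 300, 450, 600):
    print(t, round(mask_prob(t, 600, mask_scale=0.9), 3))
# 0 -> 0.0, 300 -> 0.45, 600 -> 0.9 with disc_mask_scale = 0.9 and timesteps = 600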
@@ -70,6 +73,11 @@ class HybridDiffusionModel(nn.Module):
disc_vocab_sizes: List[int],
time_dim: int = 64,
hidden_dim: int = 256,
num_layers: int = 1,
dropout: float = 0.0,
ff_mult: int = 2,
pos_dim: int = 64,
use_pos_embed: bool = True,
cond_vocab_size: int = 0,
cond_dim: int = 32,
use_tanh_eps: bool = False,
@@ -82,6 +90,8 @@ class HybridDiffusionModel(nn.Module):
self.time_embed = SinusoidalTimeEmbedding(time_dim)
self.use_tanh_eps = use_tanh_eps
self.eps_scale = eps_scale
self.pos_dim = pos_dim
self.use_pos_embed = use_pos_embed
self.cond_vocab_size = cond_vocab_size
self.cond_dim = cond_dim
@@ -96,9 +106,22 @@ class HybridDiffusionModel(nn.Module):
disc_embed_dim = sum(e.embedding_dim for e in self.disc_embeds)
self.cont_proj = nn.Linear(cont_dim, cont_dim)
in_dim = cont_dim + disc_embed_dim + time_dim + (cond_dim if self.cond_embed is not None else 0)
pos_dim = pos_dim if use_pos_embed else 0
in_dim = cont_dim + disc_embed_dim + time_dim + pos_dim + (cond_dim if self.cond_embed is not None else 0)
self.in_proj = nn.Linear(in_dim, hidden_dim)
self.backbone = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
self.backbone = nn.GRU(
hidden_dim,
hidden_dim,
num_layers=num_layers,
dropout=dropout if num_layers > 1 else 0.0,
batch_first=True,
)
self.post_norm = nn.LayerNorm(hidden_dim)
self.post_ff = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim * ff_mult),
nn.GELU(),
nn.Linear(hidden_dim * ff_mult, hidden_dim),
)
self.cont_head = nn.Linear(hidden_dim, cont_dim)
self.disc_heads = nn.ModuleList([
@@ -110,6 +133,9 @@ class HybridDiffusionModel(nn.Module):
"""x_cont: (B,T,Cc), x_disc: (B,T,Cd) with integer tokens."""
time_emb = self.time_embed(t)
time_emb = time_emb.unsqueeze(1).expand(-1, x_cont.size(1), -1)
pos_emb = None
if self.use_pos_embed and self.pos_dim > 0:
pos_emb = self._positional_encoding(x_cont.size(1), self.pos_dim, x_cont.device)
disc_embs = []
for i, emb in enumerate(self.disc_embeds):
@@ -124,15 +150,28 @@ class HybridDiffusionModel(nn.Module):
cont_feat = self.cont_proj(x_cont)
parts = [cont_feat, disc_feat, time_emb]
if pos_emb is not None:
parts.append(pos_emb.unsqueeze(0).expand(x_cont.size(0), -1, -1))
if cond_feat is not None:
parts.append(cond_feat)
feat = torch.cat(parts, dim=-1)
feat = self.in_proj(feat)
out, _ = self.backbone(feat)
out = self.post_norm(out)
out = out + self.post_ff(out)
eps_pred = self.cont_head(out)
if self.use_tanh_eps:
eps_pred = torch.tanh(eps_pred) * self.eps_scale
logits = [head(out) for head in self.disc_heads]
return eps_pred, logits
@staticmethod
def _positional_encoding(seq_len: int, dim: int, device: torch.device) -> torch.Tensor:
pos = torch.arange(seq_len, device=device).float().unsqueeze(1)
div = torch.exp(torch.arange(0, dim, 2, device=device).float() * (-math.log(10000.0) / dim))
pe = torch.zeros(seq_len, dim, device=device)
pe[:, 0::2] = torch.sin(pos * div)
pe[:, 1::2] = torch.cos(pos * div)
return pe
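
The table built by _positional_encoding is the usual fixed sinusoidal encoding, expanded across the batch in forward(). A short standalone check of its shape and first row (it mirrors the method above without importing the model):

import math
import torch

seq_len, dim = 128, 64
pos = torch.arange(seq_len).float().unsqueeze(1)
div = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
pe = torch.zeros(seq_len, dim)
pe[:, 0::2] = torch.sin(pos * div)
pe[:, 1::2] = torch.cos(pos * div)

print(pe.shape)                                  # torch.Size([128, 64])
print(pe[0, :4])                                 # position 0: sin entries are 0, cos entries are 1
print(pe.unsqueeze(0).expand(8, -1, -1).shape)   # broadcast over a batch of 8, as in forward()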

View File

@@ -5,7 +5,7 @@ import json
from pathlib import Path
from typing import Optional
from data_utils import compute_cont_stats, build_disc_stats, load_split
from data_utils import compute_cont_stats, build_disc_stats, load_split, choose_cont_transforms
from platform_utils import safe_path, ensure_dir
BASE_DIR = Path(__file__).resolve().parent
@@ -27,20 +27,25 @@ def main(max_rows: Optional[int] = None):
raise SystemExit("no train files found under %s" % str(DATA_GLOB))
data_paths = [safe_path(p) for p in data_paths]
mean, std, vmin, vmax, int_like, max_decimals = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows)
transforms, _ = choose_cont_transforms(data_paths, cont_cols, max_rows=max_rows)
cont_stats = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows, transforms=transforms)
vocab, top_token = build_disc_stats(data_paths, disc_cols, max_rows=max_rows)
ensure_dir(OUT_STATS.parent)
with open(safe_path(OUT_STATS), "w", encoding="utf-8") as f:
json.dump(
{
"mean": mean,
"std": std,
"min": vmin,
"max": vmax,
"int_like": int_like,
"max_decimals": max_decimals,
"max_rows": max_rows,
"mean": cont_stats["mean"],
"std": cont_stats["std"],
"raw_mean": cont_stats["raw_mean"],
"raw_std": cont_stats["raw_std"],
"min": cont_stats["min"],
"max": cont_stats["max"],
"int_like": cont_stats["int_like"],
"max_decimals": cont_stats["max_decimals"],
"transform": cont_stats["transform"],
"skew": cont_stats["skew"],
"max_rows": cont_stats["max_rows"],
},
f,
indent=2,

View File

@@ -45,7 +45,7 @@
"P3_PIT01": 668.9722350000003,
"P4_HT_FD": -0.00010012580000000082,
"P4_HT_LD": 35.41945000099953,
"P4_HT_PO": 35.4085699912002,
"P4_HT_PO": 2.6391372939040414,
"P4_LD": 365.3833745803986,
"P4_ST_FD": -6.5205999999999635e-06,
"P4_ST_GOV": 17801.81294499996,
@@ -100,7 +100,7 @@
"P3_PIT01": 1168.1071264424027,
"P4_HT_FD": 0.002032582380617592,
"P4_HT_LD": 33.212361169253235,
"P4_HT_PO": 31.187825914515162,
"P4_HT_PO": 1.7636196192459512,
"P4_LD": 59.736616589045646,
"P4_ST_FD": 0.0016428787127432496,
"P4_ST_GOV": 1740.5997458128215,
@@ -109,6 +109,116 @@
"P4_ST_PT01": 22.459962818146252,
"P4_ST_TT01": 24.745939350221477
},
"raw_mean": {
"P1_B2004": 0.08649086820000026,
"P1_B2016": 1.376161456000001,
"P1_B3004": 396.1861596906018,
"P1_B3005": 1037.372384413793,
"P1_B4002": 32.564872940799994,
"P1_B4005": 65.98190757240047,
"P1_B400B": 1925.0391570245934,
"P1_B4022": 36.28908066800001,
"P1_FCV02Z": 21.744261118400036,
"P1_FCV03D": 57.36123274140044,
"P1_FCV03Z": 58.05084519640002,
"P1_FT01": 184.18615112319728,
"P1_FT01Z": 851.8781750705965,
"P1_FT02": 1255.8572173544069,
"P1_FT02Z": 1925.0210755194114,
"P1_FT03": 269.37285885780574,
"P1_FT03Z": 1037.366172230601,
"P1_LCV01D": 11.228849048599963,
"P1_LCV01Z": 10.991610181600016,
"P1_LIT01": 396.8845311109994,
"P1_PCV01D": 53.80101618419986,
"P1_PCV01Z": 54.646640287199595,
"P1_PCV02Z": 12.017773542800072,
"P1_PIT01": 1.3692859488000075,
"P1_PIT02": 0.44459071260000227,
"P1_TIT01": 35.64255813999988,
"P1_TIT02": 36.44807823060023,
"P2_24Vdc": 28.0280019013999,
"P2_CO_rpm": 54105.64434999997,
"P2_HILout": 712.0588667425922,
"P2_MSD": 763.19324,
"P2_SIT01": 778.7769850000013,
"P2_SIT02": 778.7778935471981,
"P2_VT01": 11.914949448200044,
"P2_VXT02": -3.5267871940000175,
"P2_VXT03": -1.5520904921999914,
"P2_VYT02": 3.796112737600002,
"P2_VYT03": 6.121691697000018,
"P3_FIT01": 1168.2528800000014,
"P3_LCP01D": 4675.465239999989,
"P3_LCV01D": 7445.208720000017,
"P3_LIT01": 13728.982314999852,
"P3_PIT01": 668.9722350000003,
"P4_HT_FD": -0.00010012580000000082,
"P4_HT_LD": 35.41945000099953,
"P4_HT_PO": 35.4085699912002,
"P4_LD": 365.3833745803986,
"P4_ST_FD": -6.5205999999999635e-06,
"P4_ST_GOV": 17801.81294499996,
"P4_ST_LD": 329.83259218199964,
"P4_ST_PO": 330.1079461497967,
"P4_ST_PT01": 10047.679605000127,
"P4_ST_TT01": 27606.860070000155
},
"raw_std": {
"P1_B2004": 0.024492489898690458,
"P1_B2016": 0.12949272564759745,
"P1_B3004": 10.16264800653289,
"P1_B3005": 70.85697659109,
"P1_B4002": 0.7578213113008355,
"P1_B4005": 41.80065314991797,
"P1_B400B": 1176.6445547448632,
"P1_B4022": 0.8221115066487089,
"P1_FCV02Z": 39.11843197764177,
"P1_FCV03D": 7.889507447726625,
"P1_FCV03Z": 8.046068905945717,
"P1_FT01": 30.80117031882856,
"P1_FT01Z": 91.2786865433318,
"P1_FT02": 879.7163277334494,
"P1_FT02Z": 1176.6699531305117,
"P1_FT03": 38.18015841964941,
"P1_FT03Z": 70.73100774436428,
"P1_LCV01D": 3.3355655415557597,
"P1_LCV01Z": 3.386332233773545,
"P1_LIT01": 10.578714760104122,
"P1_PCV01D": 19.61567943613885,
"P1_PCV01Z": 19.778754467302086,
"P1_PCV02Z": 0.0048047978931599995,
"P1_PIT01": 0.0776614954053113,
"P1_PIT02": 0.44823231815652304,
"P1_TIT01": 0.5986678527528814,
"P1_TIT02": 1.1892341204521049,
"P2_24Vdc": 0.003208842504097816,
"P2_CO_rpm": 20.575477821507334,
"P2_HILout": 8.178853379908608,
"P2_MSD": 1.0,
"P2_SIT01": 3.894535775667256,
"P2_SIT02": 3.8824770788579395,
"P2_VT01": 0.06812990916670247,
"P2_VXT02": 0.43104157117568803,
"P2_VXT03": 0.26894251958139775,
"P2_VYT02": 0.4610907883207586,
"P2_VYT03": 0.30596429385075474,
"P3_FIT01": 1787.2987693141868,
"P3_LCP01D": 5145.4094261812725,
"P3_LCV01D": 6785.602781765096,
"P3_LIT01": 4060.915441872745,
"P3_PIT01": 1168.1071264424027,
"P4_HT_FD": 0.002032582380617592,
"P4_HT_LD": 33.21236116925323,
"P4_HT_PO": 31.18782591451516,
"P4_LD": 59.736616589045646,
"P4_ST_FD": 0.0016428787127432496,
"P4_ST_GOV": 1740.5997458128213,
"P4_ST_LD": 35.86633288900077,
"P4_ST_PO": 32.375012735256696,
"P4_ST_PT01": 22.45996281814625,
"P4_ST_TT01": 24.745939350221487
},
"min": {
"P1_B2004": 0.03051,
"P1_B2016": 0.94729,
@@ -329,5 +439,115 @@
"P4_ST_PT01": 2,
"P4_ST_TT01": 2
},
"transform": {
"P1_B2004": "none",
"P1_B2016": "none",
"P1_B3004": "none",
"P1_B3005": "none",
"P1_B4002": "none",
"P1_B4005": "none",
"P1_B400B": "none",
"P1_B4022": "none",
"P1_FCV02Z": "none",
"P1_FCV03D": "none",
"P1_FCV03Z": "none",
"P1_FT01": "none",
"P1_FT01Z": "none",
"P1_FT02": "none",
"P1_FT02Z": "none",
"P1_FT03": "none",
"P1_FT03Z": "none",
"P1_LCV01D": "none",
"P1_LCV01Z": "none",
"P1_LIT01": "none",
"P1_PCV01D": "none",
"P1_PCV01Z": "none",
"P1_PCV02Z": "none",
"P1_PIT01": "none",
"P1_PIT02": "none",
"P1_TIT01": "none",
"P1_TIT02": "none",
"P2_24Vdc": "none",
"P2_CO_rpm": "none",
"P2_HILout": "none",
"P2_MSD": "none",
"P2_SIT01": "none",
"P2_SIT02": "none",
"P2_VT01": "none",
"P2_VXT02": "none",
"P2_VXT03": "none",
"P2_VYT02": "none",
"P2_VYT03": "none",
"P3_FIT01": "none",
"P3_LCP01D": "none",
"P3_LCV01D": "none",
"P3_LIT01": "none",
"P3_PIT01": "none",
"P4_HT_FD": "none",
"P4_HT_LD": "none",
"P4_HT_PO": "log1p",
"P4_LD": "none",
"P4_ST_FD": "none",
"P4_ST_GOV": "none",
"P4_ST_LD": "none",
"P4_ST_PO": "none",
"P4_ST_PT01": "none",
"P4_ST_TT01": "none"
},
"skew": {
"P1_B2004": -2.876938578031295e-05,
"P1_B2016": 2.014565216651284e-06,
"P1_B3004": 6.625985939357487e-06,
"P1_B3005": -9.917489652810193e-06,
"P1_B4002": 1.4641465884161855e-05,
"P1_B4005": -1.2370279269006856e-05,
"P1_B400B": -1.4116897198097317e-05,
"P1_B4022": 1.1162291352215598e-05,
"P1_FCV02Z": 2.532521501167817e-05,
"P1_FCV03D": 4.2517931711793e-06,
"P1_FCV03Z": 4.301856332440012e-06,
"P1_FT01": -1.3345735264961829e-05,
"P1_FT01Z": -4.2554413198354234e-05,
"P1_FT02": -1.0289230789249066e-05,
"P1_FT02Z": -1.4116856909216661e-05,
"P1_FT03": -4.341090521713463e-06,
"P1_FT03Z": -9.964308983887345e-06,
"P1_LCV01D": 2.541312481372867e-06,
"P1_LCV01Z": 2.5806433622267527e-06,
"P1_LIT01": 7.716120912717401e-06,
"P1_PCV01D": 2.113459306618771e-05,
"P1_PCV01Z": 2.0632832525407433e-05,
"P1_PCV02Z": 4.2639616636720384e-08,
"P1_PIT01": 2.079887220863843e-05,
"P1_PIT02": 5.003507344873546e-05,
"P1_TIT01": 9.553657000925262e-06,
"P1_TIT02": 2.1170357380515215e-05,
"P2_24Vdc": 2.735770838906968e-07,
"P2_CO_rpm": -8.124011608472296e-06,
"P2_HILout": -4.086282393330704e-06,
"P2_MSD": 0.0,
"P2_SIT01": -7.418240348817199e-06,
"P2_SIT02": -7.457826456660247e-06,
"P2_VT01": 1.247484205979928e-07,
"P2_VXT02": 6.53499778855353e-07,
"P2_VXT03": 5.32656056809399e-06,
"P2_VYT02": 9.483158480759724e-07,
"P2_VYT03": 2.128755351566922e-06,
"P3_FIT01": 2.2828575320599336e-05,
"P3_LCP01D": 1.3040993552131866e-05,
"P3_LCV01D": 3.781324885318626e-07,
"P3_LIT01": -7.824733758742217e-06,
"P3_PIT01": 3.210613447428708e-05,
"P4_HT_FD": 9.197840236384403e-05,
"P4_HT_LD": -2.4568845167931336e-08,
"P4_HT_PO": 3.997415489949367e-07,
"P4_LD": -6.253448074273654e-07,
"P4_ST_FD": 2.3472181460829935e-07,
"P4_ST_GOV": 2.494268873407866e-06,
"P4_ST_LD": 1.6692758818969547e-06,
"P4_ST_PO": 2.45129838870492e-06,
"P4_ST_PT01": 1.7637202837434092e-05,
"P4_ST_TT01": -1.9485876142550594e-05
},
"max_rows": 50000
}

File diff suppressed because it is too large.

View File

@@ -48,6 +48,8 @@ def main():
seq_len = cfg.get("sample_seq_len", cfg.get("seq_len", 64))
batch_size = cfg.get("sample_batch_size", cfg.get("batch_size", 2))
clip_k = cfg.get("clip_k", 5.0)
data_glob = cfg.get("data_glob", "")
data_path = cfg.get("data_path", "")
run([sys.executable, str(base_dir / "prepare_data.py")])
run([sys.executable, str(base_dir / "train.py"), "--config", args.config, "--device", args.device])
run(
@@ -70,7 +72,11 @@ def main():
"--use-ema",
]
)
run([sys.executable, str(base_dir / "evaluate_generated.py")])
ref = data_glob if data_glob else data_path
if ref:
run([sys.executable, str(base_dir / "evaluate_generated.py"), "--reference", str(ref)])
else:
run([sys.executable, str(base_dir / "evaluate_generated.py")])
run([sys.executable, str(base_dir / "plot_loss.py")])

View File

@@ -47,6 +47,13 @@ def main():
cond_dim = int(cfg.get("cond_dim", 32))
use_tanh_eps = bool(cfg.get("use_tanh_eps", False))
eps_scale = float(cfg.get("eps_scale", 1.0))
model_time_dim = int(cfg.get("model_time_dim", 64))
model_hidden_dim = int(cfg.get("model_hidden_dim", 256))
model_num_layers = int(cfg.get("model_num_layers", 1))
model_dropout = float(cfg.get("model_dropout", 0.0))
model_ff_mult = int(cfg.get("model_ff_mult", 2))
model_pos_dim = int(cfg.get("model_pos_dim", 64))
model_use_pos = bool(cfg.get("model_use_pos_embed", True))
split = load_split(str(SPLIT_PATH))
time_col = split.get("time_column", "time")
@@ -67,6 +74,13 @@ def main():
model = HybridDiffusionModel(
cont_dim=len(cont_cols),
disc_vocab_sizes=vocab_sizes,
time_dim=model_time_dim,
hidden_dim=model_hidden_dim,
num_layers=model_num_layers,
dropout=model_dropout,
ff_mult=model_ff_mult,
pos_dim=model_pos_dim,
use_pos_embed=model_use_pos,
cond_vocab_size=cond_vocab_size,
cond_dim=cond_dim,
use_tanh_eps=use_tanh_eps,

View File

@@ -49,8 +49,17 @@ DEFAULTS = {
"use_condition": True,
"condition_type": "file_id",
"cond_dim": 32,
"use_tanh_eps": True,
"use_tanh_eps": False,
"eps_scale": 1.0,
"model_time_dim": 128,
"model_hidden_dim": 512,
"model_num_layers": 2,
"model_dropout": 0.1,
"model_ff_mult": 2,
"model_pos_dim": 64,
"model_use_pos_embed": True,
"disc_mask_scale": 0.9,
"shuffle_buffer": 256,
}
@@ -144,6 +153,7 @@ def main():
stats = load_json(config["stats_path"])
mean = stats["mean"]
std = stats["std"]
transforms = stats.get("transform", {})
vocab = load_json(config["vocab_path"])["vocab"]
vocab_sizes = [len(vocab[c]) for c in disc_cols]
@@ -164,6 +174,13 @@ def main():
model = HybridDiffusionModel(
cont_dim=len(cont_cols),
disc_vocab_sizes=vocab_sizes,
time_dim=int(config.get("model_time_dim", 64)),
hidden_dim=int(config.get("model_hidden_dim", 256)),
num_layers=int(config.get("model_num_layers", 1)),
dropout=float(config.get("model_dropout", 0.0)),
ff_mult=int(config.get("model_ff_mult", 2)),
pos_dim=int(config.get("model_pos_dim", 64)),
use_pos_embed=bool(config.get("model_use_pos_embed", True)),
cond_vocab_size=cond_vocab_size,
cond_dim=int(config.get("cond_dim", 32)),
use_tanh_eps=bool(config.get("use_tanh_eps", False)),
@@ -198,6 +215,8 @@ def main():
seq_len=int(config["seq_len"]),
max_batches=int(config["max_batches"]),
return_file_id=use_condition,
transforms=transforms,
shuffle_buffer=int(config.get("shuffle_buffer", 0)),
)
):
if use_condition:
@@ -215,7 +234,13 @@ def main():
x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
mask_tokens = torch.tensor(vocab_sizes, device=device)
x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, int(config["timesteps"]))
x_disc_t, mask = q_sample_discrete(
x_disc,
t,
mask_tokens,
int(config["timesteps"]),
mask_scale=float(config.get("disc_mask_scale", 1.0)),
)
eps_pred, logits = model(x_cont_t, x_disc_t, t, cond)
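
The hunk ends right after the forward pass, so the loss itself is not shown here. A hedged sketch of one common way to combine the two heads, consistent with the lambda key in the config (0.5): MSE on the predicted continuous noise plus cross-entropy on the discrete positions that were actually masked. This is illustrative only, not the repository's loss code:

import torch
import torch.nn.functional as F

def hybrid_loss(eps_pred, noise, logits, x_disc0, mask, lam=0.5):
    """eps_pred/noise: (B,T,Cc); logits: list of (B,T,V_i); x_disc0: original tokens (B,T,Cd); mask: (B,T,Cd) bool."""
    cont_loss = F.mse_loss(eps_pred, noise)
    disc_terms = []
    for i, col_logits in enumerate(logits):              # one head per discrete column
        col_mask = mask[:, :, i]
        if col_mask.any():                               # only score positions that were masked
            disc_terms.append(F.cross_entropy(col_logits[col_mask], x_disc0[:, :, i][col_mask]))
    disc_loss = torch.stack(disc_terms).mean() if disc_terms else eps_pred.new_zeros(())
    return cont_loss + lam * disc_loss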