update new structure

2026-01-26 18:27:41 +08:00
parent bc838d7cd7
commit cb610281ce
6 changed files with 55 additions and 2 deletions


@@ -72,3 +72,4 @@ python example/run_pipeline.py --device auto
 - The script only samples the first 5000 rows to stay fast.
 - `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it.
 - `train.py` and `sample.py` auto-select GPU if available; otherwise they fall back to CPU.
+- Optional feature-graph mixer (`model_use_feature_graph`) adds a learnable relation prior across feature channels.


@@ -32,6 +32,9 @@
   "model_ff_mult": 2,
   "model_pos_dim": 64,
   "model_use_pos_embed": true,
+  "model_use_feature_graph": true,
+  "feature_graph_scale": 0.1,
+  "feature_graph_dropout": 0.0,
   "disc_mask_scale": 0.9,
   "cont_loss_weighting": "inv_std",
   "cont_loss_eps": 1e-6,


@@ -151,6 +151,9 @@ def main():
         ff_mult=int(cfg.get("model_ff_mult", 2)),
         pos_dim=int(cfg.get("model_pos_dim", 64)),
         use_pos_embed=bool(cfg.get("model_use_pos_embed", True)),
+        use_feature_graph=bool(cfg.get("model_use_feature_graph", False)),
+        feature_graph_scale=float(cfg.get("feature_graph_scale", 0.1)),
+        feature_graph_dropout=float(cfg.get("feature_graph_dropout", 0.0)),
         cond_vocab_size=cond_vocab_size if use_condition else 0,
         cond_dim=int(cfg.get("cond_dim", 32)),
         use_tanh_eps=bool(cfg.get("use_tanh_eps", False)),


@@ -66,6 +66,24 @@ class SinusoidalTimeEmbedding(nn.Module):
         return emb


+class FeatureGraphMixer(nn.Module):
+    """Learnable feature relation mixer (dataset-agnostic)."""
+
+    def __init__(self, dim: int, scale: float = 0.1, dropout: float = 0.0):
+        super().__init__()
+        self.scale = scale
+        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+        self.A = nn.Parameter(torch.zeros(dim, dim))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x: (B, T, D)
+        # Symmetric relation to stabilize
+        A = (self.A + self.A.t()) * 0.5
+        mixed = torch.matmul(x, A) * self.scale
+        mixed = self.dropout(mixed)
+        return x + mixed
+
+
 class HybridDiffusionModel(nn.Module):
     def __init__(
         self,
@@ -78,6 +96,9 @@ class HybridDiffusionModel(nn.Module):
         ff_mult: int = 2,
         pos_dim: int = 64,
         use_pos_embed: bool = True,
+        use_feature_graph: bool = False,
+        feature_graph_scale: float = 0.1,
+        feature_graph_dropout: float = 0.0,
         cond_vocab_size: int = 0,
         cond_dim: int = 32,
         use_tanh_eps: bool = False,
@@ -92,6 +113,7 @@ class HybridDiffusionModel(nn.Module):
         self.eps_scale = eps_scale
         self.pos_dim = pos_dim
         self.use_pos_embed = use_pos_embed
+        self.use_feature_graph = use_feature_graph
         self.cond_vocab_size = cond_vocab_size
         self.cond_dim = cond_dim

@@ -106,8 +128,17 @@ class HybridDiffusionModel(nn.Module):
         disc_embed_dim = sum(e.embedding_dim for e in self.disc_embeds)
         self.cont_proj = nn.Linear(cont_dim, cont_dim)
         self.feature_dim = cont_dim + disc_embed_dim
+
+        if use_feature_graph:
+            self.feature_graph = FeatureGraphMixer(
+                self.feature_dim,
+                scale=feature_graph_scale,
+                dropout=feature_graph_dropout,
+            )
+        else:
+            self.feature_graph = None
         pos_dim = pos_dim if use_pos_embed else 0
-        in_dim = cont_dim + disc_embed_dim + time_dim + pos_dim + (cond_dim if self.cond_embed is not None else 0)
+        in_dim = self.feature_dim + time_dim + pos_dim + (cond_dim if self.cond_embed is not None else 0)
         self.in_proj = nn.Linear(in_dim, hidden_dim)
         self.backbone = nn.GRU(
             hidden_dim,
@@ -149,7 +180,10 @@ class HybridDiffusionModel(nn.Module):
             cond_feat = self.cond_embed(cond).unsqueeze(1).expand(-1, x_cont.size(1), -1)

         cont_feat = self.cont_proj(x_cont)
-        parts = [cont_feat, disc_feat, time_emb]
+        feat = torch.cat([cont_feat, disc_feat], dim=-1)
+        if self.feature_graph is not None:
+            feat = self.feature_graph(feat)
+        parts = [feat, time_emb]
         if pos_emb is not None:
             parts.append(pos_emb.unsqueeze(0).expand(x_cont.size(0), -1, -1))
         if cond_feat is not None:
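A quick sanity check on the new layer (an illustration, not part of the commit): because `A` is initialized to zeros, the residual mixer is an exact identity at initialization, and it preserves the `(B, T, D)` shape no matter what `A` learns. A minimal standalone sketch, reusing the class body from the hunk above and assuming only `torch`:

import torch
import torch.nn as nn

class FeatureGraphMixer(nn.Module):
    """Learnable feature relation mixer (dataset-agnostic)."""

    def __init__(self, dim: int, scale: float = 0.1, dropout: float = 0.0):
        super().__init__()
        self.scale = scale
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.A = nn.Parameter(torch.zeros(dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Symmetrize so the relation acts like an undirected graph:
        # y = x + s * x @ (A + A^T) / 2
        A = (self.A + self.A.t()) * 0.5
        mixed = torch.matmul(x, A) * self.scale
        mixed = self.dropout(mixed)
        return x + mixed

mixer = FeatureGraphMixer(dim=16, scale=0.1)
x = torch.randn(4, 32, 16)          # (B, T, D) = (batch, time, feature_dim)
assert torch.equal(mixer(x), x)     # zeros init -> the layer starts as an exact identity
with torch.no_grad():
    mixer.A.normal_()               # stand-in for trained weights
assert mixer(x).shape == x.shape    # shape-preserving for any A

The small default `scale` (0.1) keeps the learned relation a gentle perturbation of the residual path, and symmetrizing `A` halves its effective degrees of freedom, which matches the "stabilize" comment in the source.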


@@ -56,6 +56,9 @@ def main():
     model_ff_mult = int(cfg.get("model_ff_mult", 2))
     model_pos_dim = int(cfg.get("model_pos_dim", 64))
     model_use_pos = bool(cfg.get("model_use_pos_embed", True))
+    model_use_feature_graph = bool(cfg.get("model_use_feature_graph", False))
+    feature_graph_scale = float(cfg.get("feature_graph_scale", 0.1))
+    feature_graph_dropout = float(cfg.get("feature_graph_dropout", 0.0))

     split = load_split(str(SPLIT_PATH))
     time_col = split.get("time_column", "time")
@@ -83,6 +86,9 @@ def main():
         ff_mult=model_ff_mult,
         pos_dim=model_pos_dim,
         use_pos_embed=model_use_pos,
+        use_feature_graph=model_use_feature_graph,
+        feature_graph_scale=feature_graph_scale,
+        feature_graph_dropout=feature_graph_dropout,
         cond_vocab_size=cond_vocab_size,
         cond_dim=cond_dim,
         use_tanh_eps=use_tanh_eps,


@@ -58,6 +58,9 @@ DEFAULTS = {
     "model_ff_mult": 2,
     "model_pos_dim": 64,
     "model_use_pos_embed": True,
+    "model_use_feature_graph": True,
+    "feature_graph_scale": 0.1,
+    "feature_graph_dropout": 0.0,
     "disc_mask_scale": 0.9,
     "shuffle_buffer": 256,
     "cont_loss_weighting": "none",  # none | inv_std
@@ -193,6 +196,9 @@ def main():
         ff_mult=int(config.get("model_ff_mult", 2)),
         pos_dim=int(config.get("model_pos_dim", 64)),
         use_pos_embed=bool(config.get("model_use_pos_embed", True)),
+        use_feature_graph=bool(config.get("model_use_feature_graph", False)),
+        feature_graph_scale=float(config.get("feature_graph_scale", 0.1)),
+        feature_graph_dropout=float(config.get("feature_graph_dropout", 0.0)),
         cond_vocab_size=cond_vocab_size,
         cond_dim=int(config.get("cond_dim", 32)),
         use_tanh_eps=bool(config.get("use_tanh_eps", False)),
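One wrinkle worth noting: this `DEFAULTS` dict turns the mixer on pipeline-wide, while every `config.get(...)` fallback in the scripts is `False`, so the feature only stays off when a caller drops the key entirely. A hypothetical sketch of that precedence (the merge logic itself is not shown in this diff; `user_cfg` is a made-up name):

# Hypothetical dict-merge illustration of the precedence implied above.
DEFAULTS = {"model_use_feature_graph": True, "feature_graph_scale": 0.1}
user_cfg = {"feature_graph_scale": 0.05}

config = {**DEFAULTS, **user_cfg}                  # user values override defaults
assert config["model_use_feature_graph"] is True   # inherited from DEFAULTS
assert config["feature_graph_scale"] == 0.05       # user override wins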