diff --git a/example/README.md b/example/README.md
index b88330f..04d2cb4 100644
--- a/example/README.md
+++ b/example/README.md
@@ -72,3 +72,4 @@ python example/run_pipeline.py --device auto
 - The script only samples the first 5000 rows to stay fast.
 - `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it.
 - `train.py` and `sample.py` auto-select GPU if available; otherwise they fall back to CPU.
+- Optional feature-graph mixer (`model_use_feature_graph`) adds a learnable relation prior across feature channels.
diff --git a/example/config.json b/example/config.json
index 90ff630..8e81f00 100644
--- a/example/config.json
+++ b/example/config.json
@@ -32,6 +32,9 @@
   "model_ff_mult": 2,
   "model_pos_dim": 64,
   "model_use_pos_embed": true,
+  "model_use_feature_graph": true,
+  "feature_graph_scale": 0.1,
+  "feature_graph_dropout": 0.0,
   "disc_mask_scale": 0.9,
   "cont_loss_weighting": "inv_std",
   "cont_loss_eps": 1e-6,
diff --git a/example/export_samples.py b/example/export_samples.py
index 65af6b6..0809bd9 100644
--- a/example/export_samples.py
+++ b/example/export_samples.py
@@ -151,6 +151,9 @@ def main():
         ff_mult=int(cfg.get("model_ff_mult", 2)),
         pos_dim=int(cfg.get("model_pos_dim", 64)),
         use_pos_embed=bool(cfg.get("model_use_pos_embed", True)),
+        use_feature_graph=bool(cfg.get("model_use_feature_graph", False)),
+        feature_graph_scale=float(cfg.get("feature_graph_scale", 0.1)),
+        feature_graph_dropout=float(cfg.get("feature_graph_dropout", 0.0)),
         cond_vocab_size=cond_vocab_size if use_condition else 0,
         cond_dim=int(cfg.get("cond_dim", 32)),
         use_tanh_eps=bool(cfg.get("use_tanh_eps", False)),
diff --git a/example/hybrid_diffusion.py b/example/hybrid_diffusion.py
index 4f37356..4245d91 100755
--- a/example/hybrid_diffusion.py
+++ b/example/hybrid_diffusion.py
@@ -66,6 +66,24 @@ class SinusoidalTimeEmbedding(nn.Module):
         return emb
 
 
+class FeatureGraphMixer(nn.Module):
+    """Learnable feature relation mixer (dataset-agnostic)."""
+
+    def __init__(self, dim: int, scale: float = 0.1, dropout: float = 0.0):
+        super().__init__()
+        self.scale = scale
+        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+        self.A = nn.Parameter(torch.zeros(dim, dim))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x: (B, T, D)
+        # Symmetric relation to stabilize
+        A = (self.A + self.A.t()) * 0.5
+        mixed = torch.matmul(x, A) * self.scale
+        mixed = self.dropout(mixed)
+        return x + mixed
+
+
 class HybridDiffusionModel(nn.Module):
     def __init__(
         self,
@@ -78,6 +96,9 @@ class HybridDiffusionModel(nn.Module):
         ff_mult: int = 2,
         pos_dim: int = 64,
         use_pos_embed: bool = True,
+        use_feature_graph: bool = False,
+        feature_graph_scale: float = 0.1,
+        feature_graph_dropout: float = 0.0,
         cond_vocab_size: int = 0,
         cond_dim: int = 32,
         use_tanh_eps: bool = False,
@@ -92,6 +113,7 @@ class HybridDiffusionModel(nn.Module):
         self.eps_scale = eps_scale
         self.pos_dim = pos_dim
         self.use_pos_embed = use_pos_embed
+        self.use_feature_graph = use_feature_graph
         self.cond_vocab_size = cond_vocab_size
         self.cond_dim = cond_dim
 
@@ -106,8 +128,17 @@ class HybridDiffusionModel(nn.Module):
         disc_embed_dim = sum(e.embedding_dim for e in self.disc_embeds)
 
         self.cont_proj = nn.Linear(cont_dim, cont_dim)
+        self.feature_dim = cont_dim + disc_embed_dim
+        if use_feature_graph:
+            self.feature_graph = FeatureGraphMixer(
+                self.feature_dim,
+                scale=feature_graph_scale,
+                dropout=feature_graph_dropout,
+            )
+        else:
+            self.feature_graph = None
         pos_dim = pos_dim if use_pos_embed else 0
-        in_dim = cont_dim + disc_embed_dim + time_dim + pos_dim + (cond_dim if self.cond_embed is not None else 0)
+        in_dim = self.feature_dim + time_dim + pos_dim + (cond_dim if self.cond_embed is not None else 0)
         self.in_proj = nn.Linear(in_dim, hidden_dim)
         self.backbone = nn.GRU(
             hidden_dim,
@@ -149,7 +180,10 @@ class HybridDiffusionModel(nn.Module):
             cond_feat = self.cond_embed(cond).unsqueeze(1).expand(-1, x_cont.size(1), -1)
 
         cont_feat = self.cont_proj(x_cont)
-        parts = [cont_feat, disc_feat, time_emb]
+        feat = torch.cat([cont_feat, disc_feat], dim=-1)
+        if self.feature_graph is not None:
+            feat = self.feature_graph(feat)
+        parts = [feat, time_emb]
         if pos_emb is not None:
             parts.append(pos_emb.unsqueeze(0).expand(x_cont.size(0), -1, -1))
         if cond_feat is not None:
diff --git a/example/sample.py b/example/sample.py
index 2c4caf6..14cda54 100755
--- a/example/sample.py
+++ b/example/sample.py
@@ -56,6 +56,9 @@ def main():
     model_ff_mult = int(cfg.get("model_ff_mult", 2))
     model_pos_dim = int(cfg.get("model_pos_dim", 64))
     model_use_pos = bool(cfg.get("model_use_pos_embed", True))
+    model_use_feature_graph = bool(cfg.get("model_use_feature_graph", False))
+    feature_graph_scale = float(cfg.get("feature_graph_scale", 0.1))
+    feature_graph_dropout = float(cfg.get("feature_graph_dropout", 0.0))
 
     split = load_split(str(SPLIT_PATH))
     time_col = split.get("time_column", "time")
@@ -83,6 +86,9 @@ def main():
         ff_mult=model_ff_mult,
         pos_dim=model_pos_dim,
         use_pos_embed=model_use_pos,
+        use_feature_graph=model_use_feature_graph,
+        feature_graph_scale=feature_graph_scale,
+        feature_graph_dropout=feature_graph_dropout,
         cond_vocab_size=cond_vocab_size,
         cond_dim=cond_dim,
         use_tanh_eps=use_tanh_eps,
diff --git a/example/train.py b/example/train.py
index 524464e..025f3d5 100755
--- a/example/train.py
+++ b/example/train.py
@@ -58,6 +58,9 @@ DEFAULTS = {
     "model_ff_mult": 2,
     "model_pos_dim": 64,
     "model_use_pos_embed": True,
+    "model_use_feature_graph": True,
+    "feature_graph_scale": 0.1,
+    "feature_graph_dropout": 0.0,
     "disc_mask_scale": 0.9,
     "shuffle_buffer": 256,
     "cont_loss_weighting": "none", # none | inv_std
@@ -193,6 +196,9 @@ def main():
         ff_mult=int(config.get("model_ff_mult", 2)),
         pos_dim=int(config.get("model_pos_dim", 64)),
         use_pos_embed=bool(config.get("model_use_pos_embed", True)),
+        use_feature_graph=bool(config.get("model_use_feature_graph", False)),
+        feature_graph_scale=float(config.get("feature_graph_scale", 0.1)),
+        feature_graph_dropout=float(config.get("feature_graph_dropout", 0.0)),
         cond_vocab_size=cond_vocab_size,
         cond_dim=int(config.get("cond_dim", 32)),
         use_tanh_eps=bool(config.get("use_tanh_eps", False)),