Update: new structure
This commit is contained in:
@@ -66,6 +66,24 @@ class SinusoidalTimeEmbedding(nn.Module):
|
||||
return emb
|
||||
|
||||
|
||||
class FeatureGraphMixer(nn.Module):
    """Dataset-agnostic learnable mixer across the feature dimension.

    Holds a dense learnable relation matrix ``A`` (zero-initialized, so the
    module behaves as an identity mapping at the start of training) and adds
    a scaled, symmetrized mixing term back onto the input as a residual.
    """

    def __init__(self, dim: int, scale: float = 0.1, dropout: float = 0.0):
        """
        Args:
            dim: size of the feature (last) dimension to mix.
            scale: multiplier applied to the mixed term before the residual add.
            dropout: dropout probability on the mixed term; 0 disables it.
        """
        super().__init__()
        self.scale = scale
        # No-op when dropout is disabled, so forward stays branch-free.
        self.dropout = nn.Identity() if dropout <= 0 else nn.Dropout(dropout)
        # Zero init => forward(x) == x until A is trained away from zero.
        self.A = nn.Parameter(torch.zeros(dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply residual feature mixing; x is expected as (B, T, D)."""
        # Symmetrize the relation matrix to stabilize training.
        sym_relation = 0.5 * (self.A + self.A.t())
        delta = self.dropout((x @ sym_relation) * self.scale)
        return x + delta
|
||||
|
||||
|
||||
class HybridDiffusionModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -78,6 +96,9 @@ class HybridDiffusionModel(nn.Module):
|
||||
ff_mult: int = 2,
|
||||
pos_dim: int = 64,
|
||||
use_pos_embed: bool = True,
|
||||
use_feature_graph: bool = False,
|
||||
feature_graph_scale: float = 0.1,
|
||||
feature_graph_dropout: float = 0.0,
|
||||
cond_vocab_size: int = 0,
|
||||
cond_dim: int = 32,
|
||||
use_tanh_eps: bool = False,
|
||||
@@ -92,6 +113,7 @@ class HybridDiffusionModel(nn.Module):
|
||||
self.eps_scale = eps_scale
|
||||
self.pos_dim = pos_dim
|
||||
self.use_pos_embed = use_pos_embed
|
||||
self.use_feature_graph = use_feature_graph
|
||||
|
||||
self.cond_vocab_size = cond_vocab_size
|
||||
self.cond_dim = cond_dim
|
||||
@@ -106,8 +128,17 @@ class HybridDiffusionModel(nn.Module):
|
||||
disc_embed_dim = sum(e.embedding_dim for e in self.disc_embeds)
|
||||
|
||||
self.cont_proj = nn.Linear(cont_dim, cont_dim)
|
||||
self.feature_dim = cont_dim + disc_embed_dim
|
||||
if use_feature_graph:
|
||||
self.feature_graph = FeatureGraphMixer(
|
||||
self.feature_dim,
|
||||
scale=feature_graph_scale,
|
||||
dropout=feature_graph_dropout,
|
||||
)
|
||||
else:
|
||||
self.feature_graph = None
|
||||
pos_dim = pos_dim if use_pos_embed else 0
|
||||
in_dim = cont_dim + disc_embed_dim + time_dim + pos_dim + (cond_dim if self.cond_embed is not None else 0)
|
||||
in_dim = self.feature_dim + time_dim + pos_dim + (cond_dim if self.cond_embed is not None else 0)
|
||||
self.in_proj = nn.Linear(in_dim, hidden_dim)
|
||||
self.backbone = nn.GRU(
|
||||
hidden_dim,
|
||||
@@ -149,7 +180,10 @@ class HybridDiffusionModel(nn.Module):
|
||||
cond_feat = self.cond_embed(cond).unsqueeze(1).expand(-1, x_cont.size(1), -1)
|
||||
|
||||
cont_feat = self.cont_proj(x_cont)
|
||||
parts = [cont_feat, disc_feat, time_emb]
|
||||
feat = torch.cat([cont_feat, disc_feat], dim=-1)
|
||||
if self.feature_graph is not None:
|
||||
feat = self.feature_graph(feat)
|
||||
parts = [feat, time_emb]
|
||||
if pos_emb is not None:
|
||||
parts.append(pos_emb.unsqueeze(0).expand(x_cont.size(0), -1, -1))
|
||||
if cond_feat is not None:
|
||||
|
||||
Reference in New Issue
Block a user