transformer
docs/README.md (new file, 12 lines)
@@ -0,0 +1,12 @@
# Documentation Index

This folder tracks project decisions, experiments, and evolving ideas.

- `decisions.md`: design/architecture changes and rationales
- `experiments.md`: experiment runs and results
- `ideas.md`: future ideas and hypotheses

Conventions:
- Append new entries instead of overwriting old ones.
- Record exact config file and key overrides when possible.
- Keep metrics in the order: avg_ks / avg_jsd / avg_lag1_diff.
docs/decisions.md (new file, 35 lines)
@@ -0,0 +1,35 @@
# Design & Decision Log

## 2026-01-26 — Two-stage temporal backbone (GRU) + residual diffusion
- **Decision**: Add a stage-1 GRU trend model, then train diffusion on residuals (see the sketch below).
- **Why**: Separate temporal consistency from distribution alignment.
- **Files**:
  - `example/hybrid_diffusion.py` (added `TemporalGRUGenerator`)
  - `example/train.py` (two-stage training + residual diffusion)
  - `example/sample.py`, `example/export_samples.py` (trend + residual synthesis)
  - `example/config.json` (temporal hyperparameters)
- **Expected effect**: improves lag-1 consistency; may hurt KS if the residual distribution drifts.
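For orientation, a minimal, self-contained sketch of the two-stage idea (hypothetical class and variable names; the real implementation is `TemporalGRUGenerator` plus the training loop in `example/train.py`):

```python
import torch
import torch.nn as nn

class TrendGRU(nn.Module):
    """Toy stand-in for the stage-1 trend model."""
    def __init__(self, dim: int, hidden: int = 64):
        super().__init__()
        self.gru = nn.GRU(dim, hidden, batch_first=True)
        self.head = nn.Linear(hidden, dim)

    def forward(self, x):            # x: (batch, seq, features)
        h, _ = self.gru(x)
        return self.head(h)          # predicted trend, same shape as x

trend_model = TrendGRU(dim=8)
x = torch.randn(4, 32, 8)            # toy batch of sequences

# Stage 1: next-step prediction (the L_temporal objective).
pred_next = trend_model(x[:, :-1])
loss_temporal = nn.functional.mse_loss(pred_next, x[:, 1:])

# Stage 2: the diffusion model is trained on residuals, not raw x0.
with torch.no_grad():
    trend = trend_model(x)
residual = x - trend                  # diffusion training target
```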

## 2026-01-26 — Residual distribution alignment losses
- **Decision**: Apply distribution losses to residuals, not raw x0 (see the sketch below).
- **Why**: The diffusion stage models residuals, so alignment should target the residual distribution.
- **Files**:
  - `example/train.py` (quantile loss on residuals)
  - `example/config.json` (quantile weight)
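A rough sketch of what a quantile-matching penalty on residuals could look like (assumed form and quantile levels; the actual loss and its weight live in `example/train.py` / `example/config.json`):

```python
import torch

def residual_quantile_loss(pred_residual, real_residual, qs=(0.05, 0.25, 0.5, 0.75, 0.95)):
    """Penalize mismatch between predicted and real residual quantiles, per feature."""
    levels = torch.tensor(qs, device=pred_residual.device)
    # Flatten batch/time so quantiles are taken over all residual values per feature.
    p = pred_residual.reshape(-1, pred_residual.shape[-1])
    r = real_residual.reshape(-1, real_residual.shape[-1])
    pq = torch.quantile(p, levels, dim=0)   # (num_quantiles, num_features)
    rq = torch.quantile(r, levels, dim=0)
    return (pq - rq).abs().mean()
```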

## 2026-01-26 — SNR-weighted loss + residual stats
- **Decision**: Add an SNR-weighted loss and residual mean/std regularization (see the sketch below).
- **Why**: Stabilize diffusion training and improve KS.
- **Files**:
  - `example/train.py`
  - `example/config.json`
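A hedged sketch of the two additions, using one common SNR weighting (the min-SNR-gamma form) and a simple moment penalty; the exact formulas in `example/train.py` may differ:

```python
import torch

def snr_weighted_mse(eps_pred, eps, alphas_cumprod, t, gamma=5.0):
    """Min-SNR-style weighting for an eps-prediction diffusion loss."""
    a_bar = alphas_cumprod[t]                    # (batch,)
    snr = a_bar / (1.0 - a_bar)
    weight = torch.clamp(snr, max=gamma) / snr   # down-weights very easy (high-SNR) steps
    per_sample = ((eps_pred - eps) ** 2).flatten(1).mean(dim=1)
    return (weight * per_sample).mean()

def residual_stat_penalty(generated_residual, target_mean=0.0, target_std=1.0):
    """Keep generated residual mean/std close to the training statistics."""
    return ((generated_residual.mean() - target_mean).abs()
            + (generated_residual.std() - target_std).abs())
```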

## 2026-01-26 — Switchable backbone (GRU vs Transformer)
- **Decision**: Make the diffusion backbone configurable (`backbone_type`) with a Transformer encoder option.
- **Why**: Test whether self‑attention reduces temporal vs distribution competition without altering the two‑stage design.
- **Files**:
  - `example/hybrid_diffusion.py`
  - `example/train.py`
  - `example/sample.py`
  - `example/export_samples.py`
  - `example/config.json`
docs/experiments.md (new file, 29 lines)
@@ -0,0 +1,29 @@
# Experiment Log

## Format
```
YYYY-MM-DD
- Config: <config file or key overrides>
- Result: avg_ks / avg_jsd / avg_lag1_diff
- Notes
```

## 2026-01-26
- Config: `example/config_no_temporal.json` (baseline)
- Result: 0.6474156 / 0.0576699 / 0.1981700
- Notes: no temporal stage; better KS, worse lag-1.

## 2026-01-26
- Config: `example/config_temporal_strong.json` (two-stage)
- Result: 0.6892453 / 0.0564408 / 0.1568776
- Notes: lag-1 improves, KS degrades; residual drift remains.

## 2026-01-26
- Config: `example/config.json` (two-stage residual diffusion; user run on Windows)
- Result: 0.7131993 / 0.0327603 / 0.2327633
- Notes: user-reported metrics after temporal stage + residual diffusion.

## 2026-01-26
- Config: `example/config.json` (two-stage residual diffusion; user run on Windows)
- Result: 0.7096230 / 0.0331810 / 0.1898416
- Notes: slight KS improvement, lag-1 improves; still a distribution/temporal trade-off.
docs/ideas.md (new file, 13 lines)
@@ -0,0 +1,13 @@
# Ideas & Hypotheses

## Transformer as backbone (Plan B)
- Hypothesis: self-attention may better capture long-range dependencies and reduce the conflict between temporal consistency and distribution matching.
- Risk: higher compute cost, potentially less stable training.
- Status: implemented as `backbone_type = "transformer"` in config.
- Experiment: compare GRU vs Transformer using `run_compare.py`.

## Residual standardization
- Hypothesis: standardizing residuals before diffusion reduces drift and improves KS (see the sketch below).

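A minimal sketch of this idea, with assumed shapes and names (not the project's actual code):

```python
import torch

residual = torch.randn(4, 32, 8)             # (batch, seq, features), toy data

# Standardize residuals per feature before diffusion...
mu = residual.mean(dim=(0, 1), keepdim=True)
sigma = residual.std(dim=(0, 1), keepdim=True).clamp_min(1e-6)
residual_norm = (residual - mu) / sigma       # fed to the diffusion model

# ...and undo the transform when adding sampled residuals back to the trend.
sampled = torch.randn_like(residual_norm)     # stand-in for diffusion samples
residual_denorm = sampled * sigma + mu
```
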
## Two-stage training with curriculum
- Hypothesis: train diffusion on residuals only after the temporal GRU converges to low error.
@@ -32,6 +32,11 @@
  "model_ff_mult": 2,
  "model_pos_dim": 64,
  "model_use_pos_embed": true,
  "backbone_type": "transformer",
  "transformer_num_layers": 2,
  "transformer_nhead": 4,
  "transformer_ff_dim": 512,
  "transformer_dropout": 0.1,
  "disc_mask_scale": 0.9,
  "cont_loss_weighting": "inv_std",
  "cont_loss_eps": 1e-6,
@@ -32,6 +32,11 @@
  "model_ff_mult": 2,
  "model_pos_dim": 64,
  "model_use_pos_embed": true,
  "backbone_type": "transformer",
  "transformer_num_layers": 2,
  "transformer_nhead": 4,
  "transformer_ff_dim": 512,
  "transformer_dropout": 0.1,
  "disc_mask_scale": 0.9,
  "cont_loss_weighting": "inv_std",
  "cont_loss_eps": 1e-6,
@@ -32,6 +32,11 @@
  "model_ff_mult": 2,
  "model_pos_dim": 64,
  "model_use_pos_embed": true,
  "backbone_type": "transformer",
  "transformer_num_layers": 2,
  "transformer_nhead": 4,
  "transformer_ff_dim": 512,
  "transformer_dropout": 0.1,
  "disc_mask_scale": 0.9,
  "cont_loss_weighting": "inv_std",
  "cont_loss_eps": 1e-6,
@@ -144,6 +144,11 @@ def main():
temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
backbone_type = str(cfg.get("backbone_type", "gru"))
transformer_num_layers = int(cfg.get("transformer_num_layers", 2))
transformer_nhead = int(cfg.get("transformer_nhead", 4))
transformer_ff_dim = int(cfg.get("transformer_ff_dim", 512))
transformer_dropout = float(cfg.get("transformer_dropout", 0.1))

model = HybridDiffusionModel(
    cont_dim=len(cont_cols),
@@ -155,6 +160,11 @@ def main():
    ff_mult=int(cfg.get("model_ff_mult", 2)),
    pos_dim=int(cfg.get("model_pos_dim", 64)),
    use_pos_embed=bool(cfg.get("model_use_pos_embed", True)),
    backbone_type=backbone_type,
    transformer_num_layers=transformer_num_layers,
    transformer_nhead=transformer_nhead,
    transformer_ff_dim=transformer_ff_dim,
    transformer_dropout=transformer_dropout,
    cond_vocab_size=cond_vocab_size if use_condition else 0,
    cond_dim=int(cfg.get("cond_dim", 32)),
    use_tanh_eps=bool(cfg.get("use_tanh_eps", False)),
@@ -118,6 +118,11 @@ class HybridDiffusionModel(nn.Module):
    ff_mult: int = 2,
    pos_dim: int = 64,
    use_pos_embed: bool = True,
    backbone_type: str = "gru",  # gru | transformer
    transformer_num_layers: int = 4,
    transformer_nhead: int = 8,
    transformer_ff_dim: int = 2048,
    transformer_dropout: float = 0.1,
    cond_vocab_size: int = 0,
    cond_dim: int = 32,
    use_tanh_eps: bool = False,
@@ -132,6 +137,7 @@ class HybridDiffusionModel(nn.Module):
self.eps_scale = eps_scale
self.pos_dim = pos_dim
self.use_pos_embed = use_pos_embed
self.backbone_type = backbone_type

self.cond_vocab_size = cond_vocab_size
self.cond_dim = cond_dim
@@ -149,13 +155,24 @@ class HybridDiffusionModel(nn.Module):
pos_dim = pos_dim if use_pos_embed else 0
in_dim = cont_dim + disc_embed_dim + time_dim + pos_dim + (cond_dim if self.cond_embed is not None else 0)
self.in_proj = nn.Linear(in_dim, hidden_dim)
self.backbone = nn.GRU(
    hidden_dim,
    hidden_dim,
    num_layers=num_layers,
    dropout=dropout if num_layers > 1 else 0.0,
    batch_first=True,
)
if backbone_type == "transformer":
    encoder_layer = nn.TransformerEncoderLayer(
        d_model=hidden_dim,
        nhead=transformer_nhead,
        dim_feedforward=transformer_ff_dim,
        dropout=transformer_dropout,
        batch_first=True,
        activation="gelu",
    )
    self.backbone = nn.TransformerEncoder(encoder_layer, num_layers=transformer_num_layers)
else:
    self.backbone = nn.GRU(
        hidden_dim,
        hidden_dim,
        num_layers=num_layers,
        dropout=dropout if num_layers > 1 else 0.0,
        batch_first=True,
    )
self.post_norm = nn.LayerNorm(hidden_dim)
self.post_ff = nn.Sequential(
    nn.Linear(hidden_dim, hidden_dim * ff_mult),
@@ -197,7 +214,10 @@ class HybridDiffusionModel(nn.Module):
feat = torch.cat(parts, dim=-1)
feat = self.in_proj(feat)

out, _ = self.backbone(feat)
if self.backbone_type == "transformer":
    out = self.backbone(feat)
else:
    out, _ = self.backbone(feat)
out = self.post_norm(out)
out = out + self.post_ff(out)

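Side note on the branch above: `nn.GRU` returns an `(output, hidden)` tuple while `nn.TransformerEncoder` returns only the output tensor, which is why the forward pass must dispatch on `backbone_type`. A standalone illustration:

```python
import torch
import torch.nn as nn

feat = torch.randn(2, 16, 32)  # (batch, seq, hidden), batch_first layout

gru = nn.GRU(32, 32, batch_first=True)
gru_out, h_n = gru(feat)       # GRU returns (output, final hidden state)

layer = nn.TransformerEncoderLayer(d_model=32, nhead=4, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=1)
enc_out = encoder(feat)        # TransformerEncoder returns only the output tensor

assert gru_out.shape == enc_out.shape == feat.shape
```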
@@ -60,6 +60,11 @@ def main():
model_ff_mult = int(cfg.get("model_ff_mult", 2))
model_pos_dim = int(cfg.get("model_pos_dim", 64))
model_use_pos = bool(cfg.get("model_use_pos_embed", True))
backbone_type = str(cfg.get("backbone_type", "gru"))
transformer_num_layers = int(cfg.get("transformer_num_layers", 2))
transformer_nhead = int(cfg.get("transformer_nhead", 4))
transformer_ff_dim = int(cfg.get("transformer_ff_dim", 512))
transformer_dropout = float(cfg.get("transformer_dropout", 0.1))

split = load_split(str(SPLIT_PATH))
time_col = split.get("time_column", "time")
@@ -87,6 +92,11 @@ def main():
    ff_mult=model_ff_mult,
    pos_dim=model_pos_dim,
    use_pos_embed=model_use_pos,
    backbone_type=backbone_type,
    transformer_num_layers=transformer_num_layers,
    transformer_nhead=transformer_nhead,
    transformer_ff_dim=transformer_ff_dim,
    transformer_dropout=transformer_dropout,
    cond_vocab_size=cond_vocab_size,
    cond_dim=cond_dim,
    use_tanh_eps=use_tanh_eps,
@@ -200,6 +200,11 @@ def main():
    ff_mult=int(config.get("model_ff_mult", 2)),
    pos_dim=int(config.get("model_pos_dim", 64)),
    use_pos_embed=bool(config.get("model_use_pos_embed", True)),
    backbone_type=str(config.get("backbone_type", "gru")),
    transformer_num_layers=int(config.get("transformer_num_layers", 4)),
    transformer_nhead=int(config.get("transformer_nhead", 8)),
    transformer_ff_dim=int(config.get("transformer_ff_dim", 2048)),
    transformer_dropout=float(config.get("transformer_dropout", 0.1)),
    cond_vocab_size=cond_vocab_size,
    cond_dim=int(config.get("cond_dim", 32)),
    use_tanh_eps=bool(config.get("use_tanh_eps", False)),
report.md (21 changed lines)
@@ -58,8 +58,9 @@ Defined in `example/hybrid_diffusion.py`.
- Positional embedding (sequence index)
- Optional condition embedding (`file_id`)

**Backbone:**
**Backbone (configurable):**
- GRU (sequence modeling)
- Transformer encoder (self‑attention)
- Post LayerNorm + residual MLP

**Outputs:**
@@ -120,6 +121,7 @@ L = λ * L_cont + (1 − λ) * L_disc
- `eps` target: MSE(eps_pred, eps)
- `x0` target: MSE(x0_pred, x0)
- Optional inverse-variance weighting: `cont_loss_weighting = "inv_std"` (sketched below)
- Optional **SNR-weighted loss**: reweights MSE by SNR to stabilize diffusion training

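A minimal sketch of the inverse-variance weighting named above (assumed form, driven by `cont_loss_weighting` / `cont_loss_eps`; the actual code is in `example/train.py`):

```python
import torch

def inv_std_weighted_mse(pred, target, feature_std, eps=1e-6):
    # Weight each continuous feature by 1 / (std + eps) so that
    # low-variance features are not drowned out by heavy-tailed ones.
    w = 1.0 / (feature_std + eps)        # shape: (num_features,)
    return ((pred - target) ** 2 * w).mean()
```
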
### 6.2 Discrete Loss / 离散损失
Cross-entropy on masked positions only.
@@ -130,6 +132,10 @@ Stage‑1 GRU predicts next step:
L_temporal = MSE(pred_next, x[:,1:])
```

### 6.4 Residual Alignment Losses / 残差对齐损失
- **Quantile loss** on residuals to align distribution tails.
- **Residual mean/std penalty** to reduce drift and improve KS.

---

## 7. Data Processing / 数据处理
@@ -169,18 +175,26 @@ Metrics (with reference):
- **Discrete JSD** over vocab frequency
- **Invalid token counts**

**Metrics summary & comparison script:** `example/summary_metrics.py` (rough metric sketches below)
- Prints avg_ks / avg_jsd / avg_lag1_diff
- Appends each run to `example/results/metrics_history.csv`
- If a previous record exists, prints the delta (new vs old)

Recent run (user-reported, Windows):
- avg_ks 0.7096 / avg_jsd 0.03318 / avg_lag1_diff 0.18984

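For reference, a rough sketch of how metrics like these are typically computed on 2-D arrays of shape (num_samples, num_features); the definitions here are assumptions, and the authoritative versions live in `example/summary_metrics.py`:

```python
import numpy as np
from scipy.stats import ks_2samp
from scipy.spatial.distance import jensenshannon

def avg_ks(real, synth):
    # Mean two-sample KS statistic over continuous columns.
    return float(np.mean([ks_2samp(real[:, j], synth[:, j]).statistic
                          for j in range(real.shape[1])]))

def discrete_jsd(real_tokens, synth_tokens, vocab_size):
    # Jensen-Shannon distance between vocabulary frequency histograms.
    p = np.bincount(real_tokens, minlength=vocab_size).astype(float)
    q = np.bincount(synth_tokens, minlength=vocab_size).astype(float)
    return float(jensenshannon(p / p.sum(), q / q.sum()))

def avg_lag1_diff(real, synth):
    # Absolute difference of lag-1 autocorrelation, averaged over columns.
    def lag1(x):
        return np.corrcoef(x[:-1], x[1:])[0, 1]
    return float(np.mean([abs(lag1(real[:, j]) - lag1(synth[:, j]))
                          for j in range(real.shape[1])]))
```
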
---

## 10. Automation / 自动化
`example/run_all.py` runs all stages with config-driven paths.
`example/run_compare.py` can run a baseline vs temporal config and compute metric deltas.

---

## 11. Key Engineering Decisions / 关键工程决策
- Mixed-type diffusion: continuous + discrete split
- Two-stage training: temporal backbone first, diffusion on residuals
- Switchable backbone: GRU vs Transformer encoder for the diffusion model
- Positional + time embeddings for stability
- Optional inverse-variance weighting for continuous loss
- Log1p transforms for heavy-tailed signals
@@ -205,11 +219,12 @@ Metrics (with reference):
- KS may remain high → continuous distribution mismatch
- Lag‑1 may fluctuate → distribution vs temporal trade-off
- Continuous loss may dominate → needs careful weighting
- Transformer backbone may change stability; needs systematic comparison

---

## 14. Suggested Next Steps / 下一步建议
- Add **SNR-weighted loss** for stable diffusion training
- Compare GRU vs Transformer backbone using `run_compare.py`
- Explore **v‑prediction** for continuous branch
- Strengthen discrete diffusion (e.g., D3PM-style transitions)