transformer

This commit is contained in:
2026-01-27 00:41:42 +08:00
parent 65391910a2
commit 334db7082b
12 changed files with 175 additions and 11 deletions

@@ -58,8 +58,9 @@ Defined in `example/hybrid_diffusion.py`.
- Positional embedding (sequence index)
- Optional condition embedding (`file_id`)
**Backbone (configurable):**
- GRU (sequence modeling)
- Transformer encoder (self-attention)
- Post LayerNorm + residual MLP
**Outputs:**
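The switchable backbone can be sketched as a single module that selects its core at construction time. This is a minimal illustration, not the actual `hybrid_diffusion.py` code; the class and argument names are hypothetical.

```python
import torch
import torch.nn as nn

class DiffusionBackbone(nn.Module):
    """Illustrative configurable backbone: 'gru' or 'transformer' (names hypothetical)."""

    def __init__(self, dim: int = 64, backbone: str = "gru",
                 heads: int = 4, layers: int = 2):
        super().__init__()
        self.backbone = backbone
        if backbone == "gru":
            # GRU for sequence modeling
            self.core = nn.GRU(dim, dim, num_layers=layers, batch_first=True)
        else:
            # Transformer encoder (self-attention)
            enc = nn.TransformerEncoderLayer(dim, nhead=heads, batch_first=True)
            self.core = nn.TransformerEncoder(enc, num_layers=layers)
        self.norm = nn.LayerNorm(dim)  # post-LayerNorm
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, dim) — both cores preserve this shape
        h = self.core(x)[0] if self.backbone == "gru" else self.core(x)
        h = self.norm(h)
        return h + self.mlp(h)  # residual MLP
```

Keeping one interface for both cores is what makes the GRU-vs-Transformer comparison in `run_compare.py` a pure config switch.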
@@ -120,6 +121,7 @@ L = λ * L_cont + (1 - λ) * L_disc
- `eps` target: MSE(eps_pred, eps)
- `x0` target: MSE(x0_pred, x0)
- Optional inverse-variance weighting: `cont_loss_weighting = "inv_std"`
- Optional **SNR-weighted loss**: reweights MSE by SNR to stabilize diffusion training
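A common way to realize SNR-weighted reweighting is min-SNR clipping: downweight low-noise timesteps where the MSE would otherwise dominate. A minimal sketch, assuming an eps-prediction target and a hypothetical clip value `gamma`:

```python
import numpy as np

def snr_weight(alpha_bar: np.ndarray, gamma: float = 5.0) -> np.ndarray:
    """Min-SNR-style weight for an eps target; gamma is a hypothetical default."""
    snr = alpha_bar / (1.0 - alpha_bar)      # signal-to-noise ratio per timestep
    return np.minimum(snr, gamma) / snr      # clip high-SNR (low-noise) steps

def weighted_mse(eps_pred: np.ndarray, eps: np.ndarray,
                 alpha_bar: np.ndarray) -> float:
    """MSE(eps_pred, eps) with per-sample SNR weights."""
    w = snr_weight(alpha_bar)                                   # shape (batch,)
    per_sample = ((eps_pred - eps) ** 2).mean(axis=tuple(range(1, eps.ndim)))
    return float((w * per_sample).mean())
```

The weight is 1 at low SNR and decays as `gamma / snr` once the clip is reached, which keeps early (noisy) timesteps fully weighted.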
### 6.2 Discrete Loss / 离散损失
Cross-entropy on masked positions only.
@@ -130,6 +132,10 @@ Stage 1 GRU predicts the next step:
L_temporal = MSE(pred_next, x[:,1:])
```
### 6.4 Residual Alignment Losses / 残差对齐损失
- **Quantile loss** on residuals to align distribution tails.
- **Residual mean/std penalty** to reduce drift and improve KS.
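The two residual alignment terms above can be sketched as simple distribution-matching penalties between generated and reference residuals. Function names and the quantile grid are illustrative, not taken from the repository:

```python
import numpy as np

def quantile_alignment_loss(res_gen: np.ndarray, res_ref: np.ndarray,
                            qs=(0.05, 0.25, 0.5, 0.75, 0.95)) -> float:
    """Mean absolute gap between residual quantiles; outer qs emphasize the tails."""
    return float(np.abs(np.quantile(res_gen, qs) - np.quantile(res_ref, qs)).mean())

def moment_penalty(res_gen: np.ndarray, res_ref: np.ndarray) -> float:
    """Mean/std penalty: pulls the first two residual moments toward the reference."""
    return float((res_gen.mean() - res_ref.mean()) ** 2
                 + (res_gen.std() - res_ref.std()) ** 2)
```

Matching quantiles and moments directly targets the empirical-CDF gap that the KS statistic measures, which is why these terms can improve KS without touching the temporal loss.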
---
## 7. Data Processing / 数据处理
@@ -169,18 +175,26 @@ Metrics (with reference):
- **Discrete JSD** over vocab frequency
- **Invalid token counts**
**Metrics summary and comparison script:** `example/summary_metrics.py`
- Outputs avg_ks / avg_jsd / avg_lag1_diff
- Appends a record to `example/results/metrics_history.csv`
- If a previous record exists, outputs the delta (old vs. new comparison)
Recent run (user-reported, Windows):
- avg_ks 0.7096 / avg_jsd 0.03318 / avg_lag1_diff 0.18984
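The per-feature KS and vocab-frequency JSD summaries can be sketched with standard SciPy calls. This is an illustrative reimplementation, not the code in `summary_metrics.py`; note that `scipy.spatial.distance.jensenshannon` returns the JS *distance*, so it is squared here to get the divergence:

```python
import numpy as np
from scipy.stats import ks_2samp
from scipy.spatial.distance import jensenshannon

def continuous_ks(gen: np.ndarray, ref: np.ndarray) -> float:
    """Two-sample KS statistic per continuous feature, averaged (avg_ks-style)."""
    return float(np.mean([ks_2samp(gen[:, j], ref[:, j]).statistic
                          for j in range(gen.shape[1])]))

def discrete_jsd(gen_tokens: np.ndarray, ref_tokens: np.ndarray,
                 vocab_size: int) -> float:
    """JSD between token-frequency histograms over the vocabulary."""
    p = np.bincount(gen_tokens, minlength=vocab_size).astype(float) + 1e-12
    q = np.bincount(ref_tokens, minlength=vocab_size).astype(float) + 1e-12
    return float(jensenshannon(p / p.sum(), q / q.sum()) ** 2)
```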
---
## 10. Automation / 自动化
`example/run_all.py` runs all stages with config-driven paths.
`example/run_compare.py` can run a baseline vs temporal config and compute metric deltas.
---
## 11. Key Engineering Decisions / 关键工程决策
- Mixed-type diffusion: continuous + discrete split
- Two-stage training: temporal backbone first, diffusion on residuals
- Switchable backbone: GRU vs Transformer encoder for the diffusion model
- Positional + time embeddings for stability
- Optional inverse-variance weighting for continuous loss
- Log1p transforms for heavy-tailed signals
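The log1p decision in the list above amounts to a reversible compression of heavy tails; a minimal sketch of the forward/inverse pair:

```python
import numpy as np

def to_log_space(x: np.ndarray) -> np.ndarray:
    """log1p compresses heavy right tails while keeping zero fixed."""
    return np.log1p(x)

def from_log_space(y: np.ndarray) -> np.ndarray:
    """Exact inverse (expm1), applied to samples before computing metrics."""
    return np.expm1(y)
```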
@@ -205,11 +219,12 @@ Metrics (with reference):
- KS may remain high → continuous distribution mismatch
- Lag1 may fluctuate → distribution vs temporal trade-off
- Continuous loss may dominate → needs careful weighting
- Transformer backbone may change stability; needs systematic comparison
---
## 14. Suggested Next Steps / 下一步建议
- Add **SNR-weighted loss** for stable diffusion training
- Compare GRU vs Transformer backbone using `run_compare.py`
- Explore **v-prediction** for the continuous branch
- Strengthen discrete diffusion (e.g., D3PM-style transitions)