update
3
.gitignore
vendored
@@ -22,6 +22,9 @@ dataset/
|
||||
|
||||
# Model artifacts and results
|
||||
mask-ddpm/example/results/
|
||||
example/results/
|
||||
!example/results/cont_stats.json
|
||||
!example/results/disc_vocab.json
|
||||
*.pt
|
||||
*.pth
|
||||
*.ckpt
|
||||
|
||||
@@ -107,6 +107,14 @@ def lag1_corr(values: List[float]) -> float:
|
||||
def resolve_reference_paths(path: str) -> List[str]:
|
||||
if not path:
|
||||
return []
|
||||
if path.endswith(".json") and Path(path).exists():
|
||||
try:
|
||||
cfg = load_json(path)
|
||||
ref = cfg.get("data_glob") or cfg.get("data_path") or ""
|
||||
if ref:
|
||||
return resolve_reference_paths(str(ref))
|
||||
except Exception:
|
||||
return []
|
||||
if any(ch in path for ch in ["*", "?", "["]):
|
||||
base = Path(path).parent.resolve()
|
||||
pat = Path(path).name
|
||||
@@ -151,10 +159,11 @@ def main():
|
||||
std_ref = stats_json.get("raw_std", stats_json.get("std"))
|
||||
transforms = stats_json.get("transform", {})
|
||||
vocab = load_json(args.vocab)["vocab"]
|
||||
vocab_sets = {c: set(vocab[c].keys()) for c in disc_cols}
|
||||
vocab_sets = {c: set(vocab.get(c, {}).keys()) for c in disc_cols}
|
||||
|
||||
cont_stats = init_stats(cont_cols)
|
||||
disc_invalid = {c: 0 for c in disc_cols}
|
||||
missing_generated = {c: 0 for c in disc_cols}
|
||||
rows = 0
|
||||
|
||||
with open_csv(args.generated) as f:
|
||||
@@ -172,7 +181,11 @@ def main():
|
||||
if ref_paths:
|
||||
pass
|
||||
for c in disc_cols:
|
||||
if row[c] not in vocab_sets[c]:
|
||||
tok = row.get(c, None)
|
||||
if tok is None:
|
||||
missing_generated[c] += 1
|
||||
continue
|
||||
if tok not in vocab_sets[c]:
|
||||
disc_invalid[c] += 1
|
||||
|
||||
cont_summary = finalize_stats(cont_stats)
|
||||
@@ -192,6 +205,7 @@ def main():
|
||||
"continuous_summary": cont_summary,
|
||||
"continuous_error": cont_err,
|
||||
"discrete_invalid_counts": disc_invalid,
|
||||
"missing_generated_columns": {k: v for k, v in missing_generated.items() if v > 0},
|
||||
}
|
||||
|
||||
# Optional richer metrics using reference data
|
||||
@@ -212,7 +226,7 @@ def main():
|
||||
except Exception:
|
||||
gen_cont[c].append(0.0)
|
||||
for c in disc_cols:
|
||||
tok = row[c]
|
||||
tok = row.get(c, "")
|
||||
gen_disc[c][tok] = gen_disc[c].get(tok, 0) + 1
|
||||
|
||||
loaded = 0
|
||||
@@ -228,7 +242,7 @@ def main():
|
||||
except Exception:
|
||||
ref_cont[c].append(0.0)
|
||||
for c in disc_cols:
|
||||
tok = row[c]
|
||||
tok = row.get(c, "")
|
||||
ref_disc[c][tok] = ref_disc[c].get(tok, 0) + 1
|
||||
loaded += 1
|
||||
if args.max_rows and loaded >= args.max_rows:
|
||||
|
||||
@@ -799,7 +799,9 @@ def lines_matplotlib(generated_csv_path, cont_stats, features, out_path, max_row
|
||||
break
|
||||
return vals
|
||||
|
||||
ref_glob = resolve_reference_glob(reference_arg or str(Path(__file__).resolve().parent / "config.json"))
|
||||
ref_paths = []
|
||||
if reference_arg:
|
||||
ref_glob = resolve_reference_glob(reference_arg)
|
||||
ref_paths = sorted(Path(ref_glob).parent.glob(Path(ref_glob).name))
|
||||
ref_rows = []
|
||||
if ref_paths:
|
||||
@@ -882,7 +884,11 @@ def temporal_only_cont_values(reference_arg, type_config_path, cont_stats, seq_l
|
||||
raise SystemExit("use_temporal_stage1 is not enabled in config_used/type-config; cannot plot temporal-only")
|
||||
|
||||
temporal_pt = temporal_pt_path or str(Path(out_dir) / "temporal.pt")
|
||||
temporal_pt = str((cfg_base / temporal_pt).resolve()) if not Path(temporal_pt).is_absolute() else temporal_pt
|
||||
temporal_pt_p = Path(temporal_pt)
|
||||
if not temporal_pt_p.is_absolute():
|
||||
temporal_pt = str(temporal_pt_p.resolve()) if temporal_pt_p.exists() else str((cfg_base / temporal_pt_p).resolve())
|
||||
else:
|
||||
temporal_pt = str(temporal_pt_p)
|
||||
|
||||
def load_torch_state(path: str, device):
|
||||
try:
|
||||
@@ -1196,7 +1202,11 @@ def lines_temporal_matplotlib(reference_arg, type_config_path, cont_stats, featu
|
||||
out_dir = str((Path(__file__).resolve().parent / "results").resolve())
|
||||
|
||||
temporal_pt = temporal_pt_path or str(Path(out_dir) / "temporal.pt")
|
||||
temporal_pt = str((cfg_base / temporal_pt).resolve()) if not Path(temporal_pt).is_absolute() else temporal_pt
|
||||
temporal_pt_p = Path(temporal_pt)
|
||||
if not temporal_pt_p.is_absolute():
|
||||
temporal_pt = str(temporal_pt_p.resolve()) if temporal_pt_p.exists() else str((cfg_base / temporal_pt_p).resolve())
|
||||
else:
|
||||
temporal_pt = str(temporal_pt_p)
|
||||
|
||||
used_cfg_path = Path(out_dir) / "config_used.json"
|
||||
if used_cfg_path.exists():
|
||||
@@ -1306,6 +1316,7 @@ def cdf_grid_matplotlib(generated_csv_path, reference_arg, cont_stats, features,
|
||||
if ref_mode == "index" and ref_paths:
|
||||
idx = max(0, min(int(ref_index), len(ref_paths) - 1))
|
||||
ref_paths = [ref_paths[idx]]
|
||||
has_ref = bool(ref_paths)
|
||||
|
||||
edges_by_feat = {}
|
||||
g_hist_by_feat = {}
|
||||
@@ -1365,6 +1376,7 @@ def cdf_grid_matplotlib(generated_csv_path, reference_arg, cont_stats, features,
|
||||
xr = edges[1:]
|
||||
yr = ecdf_from_hist(r_hist_by_feat[feat])
|
||||
ax.plot(xg, yg, color="#2563eb", linewidth=1.6, label="generated")
|
||||
if has_ref:
|
||||
ax.plot(xr, yr, color="#ef4444", linewidth=1.2, alpha=0.85, label="real")
|
||||
ax.set_title(feat, fontsize=9, loc="left")
|
||||
ax.set_ylim(0, 1)
|
||||
@@ -1418,6 +1430,7 @@ def cdf_grid_types_temporal_matplotlib(reference_arg, cont_stats, type_config_pa
|
||||
if ref_mode == "index" and ref_paths:
|
||||
idx = max(0, min(int(ref_index), len(ref_paths) - 1))
|
||||
ref_paths = [ref_paths[idx]]
|
||||
has_ref = bool(ref_paths)
|
||||
|
||||
edges_by_feat = {}
|
||||
g_hist_by_feat = {}
|
||||
@@ -1474,6 +1487,7 @@ def cdf_grid_types_temporal_matplotlib(reference_arg, cont_stats, type_config_pa
|
||||
xr = edges[1:]
|
||||
yr = ecdf_from_hist(r_hist_by_feat[feat])
|
||||
ax.plot(xg, yg, color="#2563eb", linewidth=1.6, label="temporal_only")
|
||||
if has_ref:
|
||||
ax.plot(xr, yr, color="#ef4444", linewidth=1.2, alpha=0.85, label="real")
|
||||
ax.set_title("{f} (Type {t})".format(f=feat, t=tmap.get(feat)), fontsize=9, loc="left")
|
||||
ax.set_ylim(0, 1)
|
||||
@@ -1526,6 +1540,7 @@ def cdf_grid_types_matplotlib(generated_csv_path, reference_arg, cont_stats, typ
|
||||
if ref_mode == "index" and ref_paths:
|
||||
idx = max(0, min(int(ref_index), len(ref_paths) - 1))
|
||||
ref_paths = [ref_paths[idx]]
|
||||
has_ref = bool(ref_paths)
|
||||
|
||||
edges_by_feat = {}
|
||||
g_hist_by_feat = {}
|
||||
@@ -1585,6 +1600,7 @@ def cdf_grid_types_matplotlib(generated_csv_path, reference_arg, cont_stats, typ
|
||||
xr = edges[1:]
|
||||
yr = ecdf_from_hist(r_hist_by_feat[feat])
|
||||
ax.plot(xg, yg, color="#2563eb", linewidth=1.6, label="generated")
|
||||
if has_ref:
|
||||
ax.plot(xr, yr, color="#ef4444", linewidth=1.2, alpha=0.85, label="real")
|
||||
ax.set_title("{f} (Type {t})".format(f=feat, t=tmap.get(feat)), fontsize=9, loc="left")
|
||||
ax.set_ylim(0, 1)
|
||||
|
||||
121
figures/README.md
Normal file
@@ -0,0 +1,121 @@
|
||||
# Figures 说明
|
||||
|
||||
本目录里的图片主要由 [plot_benchmark.py](../example/plot_benchmark.py) 生成,用来对比:
|
||||
|
||||
- generated:完整管线采样导出的 `example/results/generated.csv`(或 post 后的 `generated_post.csv`)
|
||||
- temporal_only:只使用 `temporal.pt`(temporal stage1)的输出
|
||||
- real:由 `--reference` 指定的真实训练数据(通常是 `example/config.json` 里的 `data_glob` / `data_path` 指向的 `train*.csv.gz`)
|
||||
|
||||
## 通用图例
|
||||
|
||||
- 蓝色:模型输出(generated 或 temporal_only,具体看文件名)
|
||||
- 红色:真实数据(real)
|
||||
- 文件名包含 `with_real`:一定画了红色 real
|
||||
- 文件名包含 `generated_only`:只画模型输出(没有 real)
|
||||
- 文件名包含 `temporal`:只使用 `temporal.pt`(temporal stage1)的输出
|
||||
- 文件名包含 `ref384`:真实数据读取行数被限制到 384(通常用于和 generated 行数对齐)
|
||||
|
||||
## 每张图是干嘛的
|
||||
|
||||
### 总览与指标
|
||||
|
||||
- [benchmark_panel.svg](./benchmark_panel.svg)
|
||||
- 作用:一张图看完整评估概况(四个 panel)
|
||||
- A:连续特征均值的“generated vs real”一致性轮廓(按 real min/max 做范围归一化)
|
||||
- B:每个特征的 KS 分布差异(越小越好)
|
||||
- C:不同训练文件之间的均值漂移(数据集 shift)
|
||||
- D:多随机种子鲁棒性 + 指标历史曲线
|
||||
- 生成:`python example/plot_benchmark.py --figure panel`
|
||||
|
||||
- [benchmark_metrics.svg](./benchmark_metrics.svg)
|
||||
- 作用:只看 3 个核心指标在不同 seed 下的均值/方差(越小越好)
|
||||
- 指标:avg_ks(连续)、avg_jsd(离散)、avg_lag1_diff(滞后 1 自相关差异)
|
||||
- 生成:`python example/plot_benchmark.py --figure summary`
|
||||
|
||||
- [ranked_ks.svg](./ranked_ks.svg)
|
||||
- 作用:定位“哪些特征拖累了平均 KS”
|
||||
- 左侧:top-K KS 最大的特征
|
||||
- 右侧:移除最差特征后,平均 KS 会下降多少
|
||||
- 生成:`python example/plot_benchmark.py --figure ranked_ks`
|
||||
|
||||
### 线图(时间序列形状)
|
||||
|
||||
- [lines.svg](./lines.svg)
|
||||
- 作用:对比 generated vs real 的时间序列曲线(选定若干特征)
|
||||
- 输入:`example/results/generated.csv` + `--reference` 指向的 real
|
||||
- 生成:`python example/plot_benchmark.py --figure lines --reference example/config.json`
|
||||
|
||||
- [lines_generated_type4.svg](./lines_generated_type4.svg)
|
||||
- 作用:Type4(`P1_PIT02,P2_SIT02,P1_FT03`)的 generated 时间序列(可能不含 real,取决于生成时的参数)
|
||||
- 生成:通常来自 `--figure lines --lines-features ... --out ...`
|
||||
|
||||
- [lines_generated_type4_with_real.svg](./lines_generated_type4_with_real.svg)
|
||||
- 作用:Type4 的 generated vs real 时间序列对比
|
||||
- 生成(示例):
|
||||
- `python example/plot_benchmark.py --figure lines --generated example/results/generated.csv --reference example/config.json --cont-stats example/results/cont_stats.json --lines-features P1_PIT02,P2_SIT02,P1_FT03 --lines-max-rows 2048 --out figures/lines_generated_type4_with_real.svg`
|
||||
|
||||
- [lines_temporal.svg](./lines_temporal.svg)
|
||||
- 作用:temporal_only vs real 的时间序列曲线(temporal_only 即只使用 temporal stage1 的输出)
|
||||
- 输入:`temporal.pt` + `--reference` 指向的 real
|
||||
- 生成:`python example/plot_benchmark.py --figure lines_temporal --reference example/config.json --temporal-pt <path/to/temporal.pt>`
|
||||
|
||||
- [lines_temporal_type4.svg](./lines_temporal_type4.svg)
|
||||
- 作用:Type4 的 temporal_only 时间序列(可能不含 real,取决于生成时的参数)
|
||||
|
||||
- [lines_temporal_type4_with_real.svg](./lines_temporal_type4_with_real.svg)
|
||||
- 作用:Type4 的 temporal_only vs real 时间序列对比
|
||||
|
||||
### 分布图(CDF,对齐边缘/中位数/尾部)
|
||||
|
||||
- [cdf_grid.svg](./cdf_grid.svg)
|
||||
- 作用:连续特征的 CDF 网格(generated vs real)
|
||||
- 生成:`python example/plot_benchmark.py --figure cdf_grid --reference example/config.json`
|
||||
|
||||
- [cdf_grid_all.svg](./cdf_grid_all.svg)
|
||||
- 作用:`cdf_grid.svg` 的“更多特征/全量特征”版本(通常是 `--cdf-max-features 0` 或更大的上限)
|
||||
|
||||
- [cdf_grid_types.svg](./cdf_grid_types.svg)
|
||||
- 作用:按 Type1~Type6 分组的连续特征 CDF(generated vs real)
|
||||
- 生成:`python example/plot_benchmark.py --figure cdf_grid_types --reference example/config.json --type-config example/config.json`
|
||||
|
||||
- [cdf_grid_types_generated_only.svg](./cdf_grid_types_generated_only.svg)
|
||||
- 作用:按类型分组的连续特征 CDF,只画 generated(没有 real)
|
||||
|
||||
- [cdf_grid_types_generated_with_real.svg](./cdf_grid_types_generated_with_real.svg)
|
||||
- 作用:按类型分组的连续特征 CDF,generated vs real
|
||||
|
||||
- [cdf_grid_types_ref384.svg](./cdf_grid_types_ref384.svg)
|
||||
- 作用:按类型分组的连续特征 CDF,但 real 侧只读取 384 行(方便与 generated 行数对齐做“同样本量”对比)
|
||||
|
||||
- [cdf_grid_types_temporal.svg](./cdf_grid_types_temporal.svg)
|
||||
- 作用:按类型分组的连续特征 CDF,temporal_only vs real
|
||||
|
||||
- [cdf_grid_types_temporal_fast.svg](./cdf_grid_types_temporal_fast.svg)
|
||||
- 作用:`cdf_grid_types_temporal` 的快速版本(通常 seq_len/bins 更小)
|
||||
|
||||
- [cdf_grid_types_temporal_ref384.svg](./cdf_grid_types_temporal_ref384.svg)
|
||||
- 作用:按类型分组的 temporal_only CDF,但 real 侧只读取 384 行
|
||||
|
||||
- [cdf_grid_types_temporal_with_real.svg](./cdf_grid_types_temporal_with_real.svg)
|
||||
- 作用:按类型分组的 temporal_only vs real CDF
|
||||
|
||||
### 离散特征分布
|
||||
|
||||
- [disc_grid.svg](./disc_grid.svg)
|
||||
- 作用:离散特征的分布网格(generated vs real)
|
||||
- 生成:`python example/plot_benchmark.py --figure disc_grid --reference example/config.json`
|
||||
|
||||
- [disc_points.svg](./disc_points.svg)
|
||||
- 作用:离散特征的“类别占比对比点图”(每个类别两点:generated 与 real)
|
||||
- 生成:`python example/plot_benchmark.py --figure disc_points --reference example/config.json`
|
||||
|
||||
### Smoke / 临时检查图
|
||||
|
||||
以下 `_smoke_*.svg` 多数是临时/快速检查用的输出(用于验证脚本能否跑通、参数组合是否正常),不一定是最终报告图:
|
||||
|
||||
- [_smoke_cdf_all.svg](./_smoke_cdf_all.svg)
|
||||
- [_smoke_cdf_types_temporal_exclude.svg](./_smoke_cdf_types_temporal_exclude.svg)
|
||||
- [_smoke_disc_all.svg](./_smoke_disc_all.svg)
|
||||
- [_smoke_lines_temporal_cond.svg](./_smoke_lines_temporal_cond.svg)
|
||||
- [_smoke_lines_temporal_seeded.svg](./_smoke_lines_temporal_seeded.svg)
|
||||
|
||||
16788
figures/cdf_grid_types_generated_only.svg
Normal file
|
After Width: | Height: | Size: 540 KiB |
20142
figures/cdf_grid_types_generated_with_real.svg
Normal file
|
After Width: | Height: | Size: 628 KiB |
|
Before Width: | Height: | Size: 629 KiB After Width: | Height: | Size: 540 KiB |
20141
figures/cdf_grid_types_temporal_with_real.svg
Normal file
|
After Width: | Height: | Size: 629 KiB |
5845
figures/lines_generated_type4_with_real.svg
Normal file
|
After Width: | Height: | Size: 145 KiB |
|
Before Width: | Height: | Size: 191 KiB After Width: | Height: | Size: 103 KiB |
7966
figures/lines_temporal_type4_with_real.svg
Normal file
|
After Width: | Height: | Size: 196 KiB |