This commit is contained in:
MZ YANG
2026-02-12 01:46:47 +08:00
parent 26b9b8a447
commit f1afd4bf38
11 changed files with 77395 additions and 13341 deletions

3
.gitignore vendored
View File

@@ -22,6 +22,9 @@ dataset/
# Model artifacts and results
mask-ddpm/example/results/
example/results/
!example/results/cont_stats.json
!example/results/disc_vocab.json
*.pt
*.pth
*.ckpt

View File

@@ -107,6 +107,14 @@ def lag1_corr(values: List[float]) -> float:
def resolve_reference_paths(path: str) -> List[str]:
if not path:
return []
if path.endswith(".json") and Path(path).exists():
try:
cfg = load_json(path)
ref = cfg.get("data_glob") or cfg.get("data_path") or ""
if ref:
return resolve_reference_paths(str(ref))
except Exception:
return []
if any(ch in path for ch in ["*", "?", "["]):
base = Path(path).parent.resolve()
pat = Path(path).name
@@ -151,10 +159,11 @@ def main():
std_ref = stats_json.get("raw_std", stats_json.get("std"))
transforms = stats_json.get("transform", {})
vocab = load_json(args.vocab)["vocab"]
vocab_sets = {c: set(vocab[c].keys()) for c in disc_cols}
vocab_sets = {c: set(vocab.get(c, {}).keys()) for c in disc_cols}
cont_stats = init_stats(cont_cols)
disc_invalid = {c: 0 for c in disc_cols}
missing_generated = {c: 0 for c in disc_cols}
rows = 0
with open_csv(args.generated) as f:
@@ -172,7 +181,11 @@ def main():
if ref_paths:
pass
for c in disc_cols:
if row[c] not in vocab_sets[c]:
tok = row.get(c, None)
if tok is None:
missing_generated[c] += 1
continue
if tok not in vocab_sets[c]:
disc_invalid[c] += 1
cont_summary = finalize_stats(cont_stats)
@@ -192,6 +205,7 @@ def main():
"continuous_summary": cont_summary,
"continuous_error": cont_err,
"discrete_invalid_counts": disc_invalid,
"missing_generated_columns": {k: v for k, v in missing_generated.items() if v > 0},
}
# Optional richer metrics using reference data
@@ -212,7 +226,7 @@ def main():
except Exception:
gen_cont[c].append(0.0)
for c in disc_cols:
tok = row[c]
tok = row.get(c, "")
gen_disc[c][tok] = gen_disc[c].get(tok, 0) + 1
loaded = 0
@@ -228,7 +242,7 @@ def main():
except Exception:
ref_cont[c].append(0.0)
for c in disc_cols:
tok = row[c]
tok = row.get(c, "")
ref_disc[c][tok] = ref_disc[c].get(tok, 0) + 1
loaded += 1
if args.max_rows and loaded >= args.max_rows:

View File

@@ -799,8 +799,10 @@ def lines_matplotlib(generated_csv_path, cont_stats, features, out_path, max_row
break
return vals
ref_glob = resolve_reference_glob(reference_arg or str(Path(__file__).resolve().parent / "config.json"))
ref_paths = sorted(Path(ref_glob).parent.glob(Path(ref_glob).name))
ref_paths = []
if reference_arg:
ref_glob = resolve_reference_glob(reference_arg)
ref_paths = sorted(Path(ref_glob).parent.glob(Path(ref_glob).name))
ref_rows = []
if ref_paths:
idx = max(0, min(ref_index, len(ref_paths) - 1))
@@ -882,7 +884,11 @@ def temporal_only_cont_values(reference_arg, type_config_path, cont_stats, seq_l
raise SystemExit("use_temporal_stage1 is not enabled in config_used/type-config; cannot plot temporal-only")
temporal_pt = temporal_pt_path or str(Path(out_dir) / "temporal.pt")
temporal_pt = str((cfg_base / temporal_pt).resolve()) if not Path(temporal_pt).is_absolute() else temporal_pt
temporal_pt_p = Path(temporal_pt)
if not temporal_pt_p.is_absolute():
temporal_pt = str(temporal_pt_p.resolve()) if temporal_pt_p.exists() else str((cfg_base / temporal_pt_p).resolve())
else:
temporal_pt = str(temporal_pt_p)
def load_torch_state(path: str, device):
try:
@@ -1196,7 +1202,11 @@ def lines_temporal_matplotlib(reference_arg, type_config_path, cont_stats, featu
out_dir = str((Path(__file__).resolve().parent / "results").resolve())
temporal_pt = temporal_pt_path or str(Path(out_dir) / "temporal.pt")
temporal_pt = str((cfg_base / temporal_pt).resolve()) if not Path(temporal_pt).is_absolute() else temporal_pt
temporal_pt_p = Path(temporal_pt)
if not temporal_pt_p.is_absolute():
temporal_pt = str(temporal_pt_p.resolve()) if temporal_pt_p.exists() else str((cfg_base / temporal_pt_p).resolve())
else:
temporal_pt = str(temporal_pt_p)
used_cfg_path = Path(out_dir) / "config_used.json"
if used_cfg_path.exists():
@@ -1306,6 +1316,7 @@ def cdf_grid_matplotlib(generated_csv_path, reference_arg, cont_stats, features,
if ref_mode == "index" and ref_paths:
idx = max(0, min(int(ref_index), len(ref_paths) - 1))
ref_paths = [ref_paths[idx]]
has_ref = bool(ref_paths)
edges_by_feat = {}
g_hist_by_feat = {}
@@ -1365,7 +1376,8 @@ def cdf_grid_matplotlib(generated_csv_path, reference_arg, cont_stats, features,
xr = edges[1:]
yr = ecdf_from_hist(r_hist_by_feat[feat])
ax.plot(xg, yg, color="#2563eb", linewidth=1.6, label="generated")
ax.plot(xr, yr, color="#ef4444", linewidth=1.2, alpha=0.85, label="real")
if has_ref:
ax.plot(xr, yr, color="#ef4444", linewidth=1.2, alpha=0.85, label="real")
ax.set_title(feat, fontsize=9, loc="left")
ax.set_ylim(0, 1)
ax.grid(True, color="#e5e7eb")
@@ -1418,6 +1430,7 @@ def cdf_grid_types_temporal_matplotlib(reference_arg, cont_stats, type_config_pa
if ref_mode == "index" and ref_paths:
idx = max(0, min(int(ref_index), len(ref_paths) - 1))
ref_paths = [ref_paths[idx]]
has_ref = bool(ref_paths)
edges_by_feat = {}
g_hist_by_feat = {}
@@ -1474,7 +1487,8 @@ def cdf_grid_types_temporal_matplotlib(reference_arg, cont_stats, type_config_pa
xr = edges[1:]
yr = ecdf_from_hist(r_hist_by_feat[feat])
ax.plot(xg, yg, color="#2563eb", linewidth=1.6, label="temporal_only")
ax.plot(xr, yr, color="#ef4444", linewidth=1.2, alpha=0.85, label="real")
if has_ref:
ax.plot(xr, yr, color="#ef4444", linewidth=1.2, alpha=0.85, label="real")
ax.set_title("{f} (Type {t})".format(f=feat, t=tmap.get(feat)), fontsize=9, loc="left")
ax.set_ylim(0, 1)
ax.grid(True, color="#e5e7eb")
@@ -1526,6 +1540,7 @@ def cdf_grid_types_matplotlib(generated_csv_path, reference_arg, cont_stats, typ
if ref_mode == "index" and ref_paths:
idx = max(0, min(int(ref_index), len(ref_paths) - 1))
ref_paths = [ref_paths[idx]]
has_ref = bool(ref_paths)
edges_by_feat = {}
g_hist_by_feat = {}
@@ -1585,7 +1600,8 @@ def cdf_grid_types_matplotlib(generated_csv_path, reference_arg, cont_stats, typ
xr = edges[1:]
yr = ecdf_from_hist(r_hist_by_feat[feat])
ax.plot(xg, yg, color="#2563eb", linewidth=1.6, label="generated")
ax.plot(xr, yr, color="#ef4444", linewidth=1.2, alpha=0.85, label="real")
if has_ref:
ax.plot(xr, yr, color="#ef4444", linewidth=1.2, alpha=0.85, label="real")
ax.set_title("{f} (Type {t})".format(f=feat, t=tmap.get(feat)), fontsize=9, loc="left")
ax.set_ylim(0, 1)
ax.grid(True, color="#e5e7eb")

121
figures/README.md Normal file
View File

@@ -0,0 +1,121 @@
# Figures 说明
本目录里的图片主要由 [plot_benchmark.py](file:///f:/Development/modbus_diffusion/mask-ddpm/mask-ddpm/example/plot_benchmark.py) 生成,用来对比:
- generated:完整管线采样导出的 `example/results/generated.csv`(或 post 后的 `generated_post.csv`)
- temporal_only:只使用 `temporal.pt`(temporal stage1)的输出
- real:`--reference` 指定的真实训练数据(通常是 `example/config.json` 里的 `data_glob` / `data_path` 指向的 `train*.csv.gz`)
## 通用图例
- 蓝色:模型输出(generated 或 temporal_only,具体看文件名)
- 红色:真实数据(real)
- 文件名包含 `with_real`:一定画了红色 real
- 文件名包含 `generated_only`:只画模型输出(没有 real)
- 文件名包含 `temporal`:使用 `temporal.pt`(只时序 stage1)
- 文件名包含 `ref384`:真实数据读取行数被限制到 384(通常用于和 generated 行数对齐)
## 每张图是干嘛的
### 总览与指标
- [benchmark_panel.svg](./benchmark_panel.svg)
- 作用:一张图看完整评估概况(四个 panel)
- A:连续特征均值的“generated vs real”一致性轮廓(按 real min/max 做范围归一化)
- B:每个特征的 KS 分布差异(越小越好)
- C:不同训练文件之间的均值漂移(数据集 shift)
- D:多随机种子鲁棒性 + 指标历史曲线
- 生成:`python example/plot_benchmark.py --figure panel`
- [benchmark_metrics.svg](./benchmark_metrics.svg)
- 作用:只看 3 个核心指标在不同 seed 下的均值/方差(越小越好)
- 指标:avg_ks(连续)、avg_jsd(离散)、avg_lag1_diff(滞后 1 自相关差异)
- 生成:`python example/plot_benchmark.py --figure summary`
- [ranked_ks.svg](./ranked_ks.svg)
- 作用:定位“哪些特征拖累了平均 KS”
- 左侧:top-K KS 最大的特征
- 右侧:移除最差特征后,平均 KS 会下降多少
- 生成:`python example/plot_benchmark.py --figure ranked_ks`
### 线图(时间序列形状)
- [lines.svg](./lines.svg)
- 作用:对比 generated vs real 的时间序列曲线(选定若干特征)
- 输入:`example/results/generated.csv` + `--reference` 指向的 real
- 生成:`python example/plot_benchmark.py --figure lines --reference example/config.json`
- [lines_generated_type4.svg](./lines_generated_type4.svg)
- 作用:Type4(`P1_PIT02,P2_SIT02,P1_FT03`)的 generated 时间序列(可能不含 real,取决于生成时的参数)
- 生成:通常来自 `--figure lines --lines-features ... --out ...`
- [lines_generated_type4_with_real.svg](./lines_generated_type4_with_real.svg)
- 作用:Type4 的 generated vs real 时间序列对比
- 生成(示例):
- `python example/plot_benchmark.py --figure lines --generated example/results/generated.csv --reference example/config.json --cont-stats example/results/cont_stats.json --lines-features P1_PIT02,P2_SIT02,P1_FT03 --lines-max-rows 2048 --out figures/lines_generated_type4_with_real.svg`
- [lines_temporal.svg](./lines_temporal.svg)
- 作用:temporal_only vs real 的时间序列曲线(只时序 stage1)
- 输入:`temporal.pt` + `--reference` 指向的 real
- 生成:`python example/plot_benchmark.py --figure lines_temporal --reference example/config.json --temporal-pt <path/to/temporal.pt>`
- [lines_temporal_type4.svg](./lines_temporal_type4.svg)
- 作用:Type4 的 temporal_only 时间序列(可能不含 real,取决于生成时的参数)
- [lines_temporal_type4_with_real.svg](./lines_temporal_type4_with_real.svg)
- 作用:Type4 的 temporal_only vs real 时间序列对比
### 分布图(CDF:对齐边缘/中位数/尾部)
- [cdf_grid.svg](./cdf_grid.svg)
- 作用:连续特征的 CDF 网格(generated vs real)
- 生成:`python example/plot_benchmark.py --figure cdf_grid --reference example/config.json`
- [cdf_grid_all.svg](./cdf_grid_all.svg)
- 作用:`cdf_grid.svg` 的“更多特征/全量特征”版本(通常是 `--cdf-max-features 0` 或更大的上限)
- [cdf_grid_types.svg](./cdf_grid_types.svg)
- 作用:按 Type1~Type6 分组的连续特征 CDF(generated vs real)
- 生成:`python example/plot_benchmark.py --figure cdf_grid_types --reference example/config.json --type-config example/config.json`
- [cdf_grid_types_generated_only.svg](./cdf_grid_types_generated_only.svg)
- 作用:按类型分组的连续特征 CDF(只画 generated,没有 real)
- [cdf_grid_types_generated_with_real.svg](./cdf_grid_types_generated_with_real.svg)
- 作用:按类型分组的连续特征 CDF(generated vs real)
- [cdf_grid_types_ref384.svg](./cdf_grid_types_ref384.svg)
- 作用:按类型分组的连续特征 CDF,但 real 侧只读取 384 行(方便与 generated 行数对齐做“同样本量”对比)
- [cdf_grid_types_temporal.svg](./cdf_grid_types_temporal.svg)
- 作用:按类型分组的连续特征 CDF(temporal_only vs real)
- [cdf_grid_types_temporal_fast.svg](./cdf_grid_types_temporal_fast.svg)
- 作用:`cdf_grid_types_temporal` 的快速版本(通常 seq_len/bins 更小)
- [cdf_grid_types_temporal_ref384.svg](./cdf_grid_types_temporal_ref384.svg)
- 作用:按类型分组的 temporal_only CDF,但 real 侧只读取 384 行
- [cdf_grid_types_temporal_with_real.svg](./cdf_grid_types_temporal_with_real.svg)
- 作用:按类型分组的 temporal_only vs real CDF
### 离散特征分布
- [disc_grid.svg](./disc_grid.svg)
- 作用:离散特征的分布网格(generated vs real)
- 生成:`python example/plot_benchmark.py --figure disc_grid --reference example/config.json`
- [disc_points.svg](./disc_points.svg)
- 作用:离散特征的“类别占比对比点图”(每个类别两点:generated 与 real)
- 生成:`python example/plot_benchmark.py --figure disc_points --reference example/config.json`
### Smoke / 临时检查图
以下 `*_smoke_*.svg` 多数是临时/快速检查用输出(用于验证脚本跑通、参数组合是否正常),不一定是最终报告图:
- [_smoke_cdf_all.svg](./_smoke_cdf_all.svg)
- [_smoke_cdf_types_temporal_exclude.svg](./_smoke_cdf_types_temporal_exclude.svg)
- [_smoke_disc_all.svg](./_smoke_disc_all.svg)
- [_smoke_lines_temporal_cond.svg](./_smoke_lines_temporal_cond.svg)
- [_smoke_lines_temporal_seeded.svg](./_smoke_lines_temporal_seeded.svg)

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 540 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 628 KiB

File diff suppressed because it is too large Load Diff

Before

Width:  |  Height:  |  Size: 629 KiB

After

Width:  |  Height:  |  Size: 540 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 629 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 145 KiB

File diff suppressed because it is too large Load Diff

Before

Width:  |  Height:  |  Size: 191 KiB

After

Width:  |  Height:  |  Size: 103 KiB

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 196 KiB