This commit is contained in:
MZ YANG
2026-02-12 01:46:47 +08:00
parent 26b9b8a447
commit f1afd4bf38
11 changed files with 77395 additions and 13341 deletions

View File

@@ -799,8 +799,10 @@ def lines_matplotlib(generated_csv_path, cont_stats, features, out_path, max_row
break
return vals
ref_glob = resolve_reference_glob(reference_arg or str(Path(__file__).resolve().parent / "config.json"))
ref_paths = sorted(Path(ref_glob).parent.glob(Path(ref_glob).name))
ref_paths = []
if reference_arg:
ref_glob = resolve_reference_glob(reference_arg)
ref_paths = sorted(Path(ref_glob).parent.glob(Path(ref_glob).name))
ref_rows = []
if ref_paths:
idx = max(0, min(ref_index, len(ref_paths) - 1))
@@ -882,7 +884,11 @@ def temporal_only_cont_values(reference_arg, type_config_path, cont_stats, seq_l
raise SystemExit("use_temporal_stage1 is not enabled in config_used/type-config; cannot plot temporal-only")
temporal_pt = temporal_pt_path or str(Path(out_dir) / "temporal.pt")
temporal_pt = str((cfg_base / temporal_pt).resolve()) if not Path(temporal_pt).is_absolute() else temporal_pt
temporal_pt_p = Path(temporal_pt)
if not temporal_pt_p.is_absolute():
temporal_pt = str(temporal_pt_p.resolve()) if temporal_pt_p.exists() else str((cfg_base / temporal_pt_p).resolve())
else:
temporal_pt = str(temporal_pt_p)
def load_torch_state(path: str, device):
try:
@@ -1196,7 +1202,11 @@ def lines_temporal_matplotlib(reference_arg, type_config_path, cont_stats, featu
out_dir = str((Path(__file__).resolve().parent / "results").resolve())
temporal_pt = temporal_pt_path or str(Path(out_dir) / "temporal.pt")
temporal_pt = str((cfg_base / temporal_pt).resolve()) if not Path(temporal_pt).is_absolute() else temporal_pt
temporal_pt_p = Path(temporal_pt)
if not temporal_pt_p.is_absolute():
temporal_pt = str(temporal_pt_p.resolve()) if temporal_pt_p.exists() else str((cfg_base / temporal_pt_p).resolve())
else:
temporal_pt = str(temporal_pt_p)
used_cfg_path = Path(out_dir) / "config_used.json"
if used_cfg_path.exists():
@@ -1306,6 +1316,7 @@ def cdf_grid_matplotlib(generated_csv_path, reference_arg, cont_stats, features,
if ref_mode == "index" and ref_paths:
idx = max(0, min(int(ref_index), len(ref_paths) - 1))
ref_paths = [ref_paths[idx]]
has_ref = bool(ref_paths)
edges_by_feat = {}
g_hist_by_feat = {}
@@ -1365,7 +1376,8 @@ def cdf_grid_matplotlib(generated_csv_path, reference_arg, cont_stats, features,
xr = edges[1:]
yr = ecdf_from_hist(r_hist_by_feat[feat])
ax.plot(xg, yg, color="#2563eb", linewidth=1.6, label="generated")
ax.plot(xr, yr, color="#ef4444", linewidth=1.2, alpha=0.85, label="real")
if has_ref:
ax.plot(xr, yr, color="#ef4444", linewidth=1.2, alpha=0.85, label="real")
ax.set_title(feat, fontsize=9, loc="left")
ax.set_ylim(0, 1)
ax.grid(True, color="#e5e7eb")
@@ -1418,6 +1430,7 @@ def cdf_grid_types_temporal_matplotlib(reference_arg, cont_stats, type_config_pa
if ref_mode == "index" and ref_paths:
idx = max(0, min(int(ref_index), len(ref_paths) - 1))
ref_paths = [ref_paths[idx]]
has_ref = bool(ref_paths)
edges_by_feat = {}
g_hist_by_feat = {}
@@ -1474,7 +1487,8 @@ def cdf_grid_types_temporal_matplotlib(reference_arg, cont_stats, type_config_pa
xr = edges[1:]
yr = ecdf_from_hist(r_hist_by_feat[feat])
ax.plot(xg, yg, color="#2563eb", linewidth=1.6, label="temporal_only")
ax.plot(xr, yr, color="#ef4444", linewidth=1.2, alpha=0.85, label="real")
if has_ref:
ax.plot(xr, yr, color="#ef4444", linewidth=1.2, alpha=0.85, label="real")
ax.set_title("{f} (Type {t})".format(f=feat, t=tmap.get(feat)), fontsize=9, loc="left")
ax.set_ylim(0, 1)
ax.grid(True, color="#e5e7eb")
@@ -1526,6 +1540,7 @@ def cdf_grid_types_matplotlib(generated_csv_path, reference_arg, cont_stats, typ
if ref_mode == "index" and ref_paths:
idx = max(0, min(int(ref_index), len(ref_paths) - 1))
ref_paths = [ref_paths[idx]]
has_ref = bool(ref_paths)
edges_by_feat = {}
g_hist_by_feat = {}
@@ -1585,7 +1600,8 @@ def cdf_grid_types_matplotlib(generated_csv_path, reference_arg, cont_stats, typ
xr = edges[1:]
yr = ecdf_from_hist(r_hist_by_feat[feat])
ax.plot(xg, yg, color="#2563eb", linewidth=1.6, label="generated")
ax.plot(xr, yr, color="#ef4444", linewidth=1.2, alpha=0.85, label="real")
if has_ref:
ax.plot(xr, yr, color="#ef4444", linewidth=1.2, alpha=0.85, label="real")
ax.set_title("{f} (Type {t})".format(f=feat, t=tmap.get(feat)), fontsize=9, loc="left")
ax.set_ylim(0, 1)
ax.grid(True, color="#e5e7eb")