This commit is contained in:
MZ YANG
2026-02-12 01:46:47 +08:00
parent 26b9b8a447
commit f1afd4bf38
11 changed files with 77395 additions and 13341 deletions

View File

@@ -107,6 +107,14 @@ def lag1_corr(values: List[float]) -> float:
def resolve_reference_paths(path: str) -> List[str]:
if not path:
return []
if path.endswith(".json") and Path(path).exists():
try:
cfg = load_json(path)
ref = cfg.get("data_glob") or cfg.get("data_path") or ""
if ref:
return resolve_reference_paths(str(ref))
except Exception:
return []
if any(ch in path for ch in ["*", "?", "["]):
base = Path(path).parent.resolve()
pat = Path(path).name
@@ -151,10 +159,11 @@ def main():
std_ref = stats_json.get("raw_std", stats_json.get("std"))
transforms = stats_json.get("transform", {})
vocab = load_json(args.vocab)["vocab"]
vocab_sets = {c: set(vocab[c].keys()) for c in disc_cols}
vocab_sets = {c: set(vocab.get(c, {}).keys()) for c in disc_cols}
cont_stats = init_stats(cont_cols)
disc_invalid = {c: 0 for c in disc_cols}
missing_generated = {c: 0 for c in disc_cols}
rows = 0
with open_csv(args.generated) as f:
@@ -172,7 +181,11 @@ def main():
if ref_paths:
pass
for c in disc_cols:
if row[c] not in vocab_sets[c]:
tok = row.get(c, None)
if tok is None:
missing_generated[c] += 1
continue
if tok not in vocab_sets[c]:
disc_invalid[c] += 1
cont_summary = finalize_stats(cont_stats)
@@ -192,6 +205,7 @@ def main():
"continuous_summary": cont_summary,
"continuous_error": cont_err,
"discrete_invalid_counts": disc_invalid,
"missing_generated_columns": {k: v for k, v in missing_generated.items() if v > 0},
}
# Optional richer metrics using reference data
@@ -212,7 +226,7 @@ def main():
except Exception:
gen_cont[c].append(0.0)
for c in disc_cols:
tok = row[c]
tok = row.get(c, "")
gen_disc[c][tok] = gen_disc[c].get(tok, 0) + 1
loaded = 0
@@ -228,7 +242,7 @@ def main():
except Exception:
ref_cont[c].append(0.0)
for c in disc_cols:
tok = row[c]
tok = row.get(c, "")
ref_disc[c][tok] = ref_disc[c].get(tok, 0) + 1
loaded += 1
if args.max_rows and loaded >= args.max_rows:

View File

@@ -799,8 +799,10 @@ def lines_matplotlib(generated_csv_path, cont_stats, features, out_path, max_row
break
return vals
ref_glob = resolve_reference_glob(reference_arg or str(Path(__file__).resolve().parent / "config.json"))
ref_paths = sorted(Path(ref_glob).parent.glob(Path(ref_glob).name))
ref_paths = []
if reference_arg:
ref_glob = resolve_reference_glob(reference_arg)
ref_paths = sorted(Path(ref_glob).parent.glob(Path(ref_glob).name))
ref_rows = []
if ref_paths:
idx = max(0, min(ref_index, len(ref_paths) - 1))
@@ -882,7 +884,11 @@ def temporal_only_cont_values(reference_arg, type_config_path, cont_stats, seq_l
raise SystemExit("use_temporal_stage1 is not enabled in config_used/type-config; cannot plot temporal-only")
temporal_pt = temporal_pt_path or str(Path(out_dir) / "temporal.pt")
temporal_pt = str((cfg_base / temporal_pt).resolve()) if not Path(temporal_pt).is_absolute() else temporal_pt
temporal_pt_p = Path(temporal_pt)
if not temporal_pt_p.is_absolute():
temporal_pt = str(temporal_pt_p.resolve()) if temporal_pt_p.exists() else str((cfg_base / temporal_pt_p).resolve())
else:
temporal_pt = str(temporal_pt_p)
def load_torch_state(path: str, device):
try:
@@ -1196,7 +1202,11 @@ def lines_temporal_matplotlib(reference_arg, type_config_path, cont_stats, featu
out_dir = str((Path(__file__).resolve().parent / "results").resolve())
temporal_pt = temporal_pt_path or str(Path(out_dir) / "temporal.pt")
temporal_pt = str((cfg_base / temporal_pt).resolve()) if not Path(temporal_pt).is_absolute() else temporal_pt
temporal_pt_p = Path(temporal_pt)
if not temporal_pt_p.is_absolute():
temporal_pt = str(temporal_pt_p.resolve()) if temporal_pt_p.exists() else str((cfg_base / temporal_pt_p).resolve())
else:
temporal_pt = str(temporal_pt_p)
used_cfg_path = Path(out_dir) / "config_used.json"
if used_cfg_path.exists():
@@ -1306,6 +1316,7 @@ def cdf_grid_matplotlib(generated_csv_path, reference_arg, cont_stats, features,
if ref_mode == "index" and ref_paths:
idx = max(0, min(int(ref_index), len(ref_paths) - 1))
ref_paths = [ref_paths[idx]]
has_ref = bool(ref_paths)
edges_by_feat = {}
g_hist_by_feat = {}
@@ -1365,7 +1376,8 @@ def cdf_grid_matplotlib(generated_csv_path, reference_arg, cont_stats, features,
xr = edges[1:]
yr = ecdf_from_hist(r_hist_by_feat[feat])
ax.plot(xg, yg, color="#2563eb", linewidth=1.6, label="generated")
ax.plot(xr, yr, color="#ef4444", linewidth=1.2, alpha=0.85, label="real")
if has_ref:
ax.plot(xr, yr, color="#ef4444", linewidth=1.2, alpha=0.85, label="real")
ax.set_title(feat, fontsize=9, loc="left")
ax.set_ylim(0, 1)
ax.grid(True, color="#e5e7eb")
@@ -1418,6 +1430,7 @@ def cdf_grid_types_temporal_matplotlib(reference_arg, cont_stats, type_config_pa
if ref_mode == "index" and ref_paths:
idx = max(0, min(int(ref_index), len(ref_paths) - 1))
ref_paths = [ref_paths[idx]]
has_ref = bool(ref_paths)
edges_by_feat = {}
g_hist_by_feat = {}
@@ -1474,7 +1487,8 @@ def cdf_grid_types_temporal_matplotlib(reference_arg, cont_stats, type_config_pa
xr = edges[1:]
yr = ecdf_from_hist(r_hist_by_feat[feat])
ax.plot(xg, yg, color="#2563eb", linewidth=1.6, label="temporal_only")
ax.plot(xr, yr, color="#ef4444", linewidth=1.2, alpha=0.85, label="real")
if has_ref:
ax.plot(xr, yr, color="#ef4444", linewidth=1.2, alpha=0.85, label="real")
ax.set_title("{f} (Type {t})".format(f=feat, t=tmap.get(feat)), fontsize=9, loc="left")
ax.set_ylim(0, 1)
ax.grid(True, color="#e5e7eb")
@@ -1526,6 +1540,7 @@ def cdf_grid_types_matplotlib(generated_csv_path, reference_arg, cont_stats, typ
if ref_mode == "index" and ref_paths:
idx = max(0, min(int(ref_index), len(ref_paths) - 1))
ref_paths = [ref_paths[idx]]
has_ref = bool(ref_paths)
edges_by_feat = {}
g_hist_by_feat = {}
@@ -1585,7 +1600,8 @@ def cdf_grid_types_matplotlib(generated_csv_path, reference_arg, cont_stats, typ
xr = edges[1:]
yr = ecdf_from_hist(r_hist_by_feat[feat])
ax.plot(xg, yg, color="#2563eb", linewidth=1.6, label="generated")
ax.plot(xr, yr, color="#ef4444", linewidth=1.2, alpha=0.85, label="real")
if has_ref:
ax.plot(xr, yr, color="#ef4444", linewidth=1.2, alpha=0.85, label="real")
ax.set_title("{f} (Type {t})".format(f=feat, t=tmap.get(feat)), fontsize=9, loc="left")
ax.set_ylim(0, 1)
ax.grid(True, color="#e5e7eb")