2026-01-23 12:40:20 +08:00
parent 97e47be051
commit 5547e89287
4 changed files with 263 additions and 10 deletions


@@ -54,6 +54,25 @@ One-click pipeline (prepare -> train -> export -> eval -> plot):
python example/run_pipeline.py --device auto
```
## Ablation: Feature Split Variants
Generate alternative continuous/discrete splits (baseline/strict/loose):
```
python example/ablation_splits.py --data-glob "../../dataset/hai/hai-21.03/train*.csv.gz"
```
Then run prepare with a chosen split, point `split_path` in `example/config.json` at that split file (a sketch follows the commands), and train:
```
python example/prepare_data.py --split-path example/results/ablation_splits/split_strict.json
python example/train.py --config example/config.json --device cuda
```
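A minimal sketch of that config edit (`run_ablation.py` below applies the same `split_path` override programmatically, using absolute paths):
```
import json
from pathlib import Path

# Point the training config at the strict split; use whichever path style
# (relative vs absolute) the rest of example/config.json already uses.
cfg_path = Path("example/config.json")
cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
cfg["split_path"] = "example/results/ablation_splits/split_strict.json"
cfg_path.write_text(json.dumps(cfg, indent=2), encoding="utf-8")
```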
One-click ablation (runs baseline/strict/loose end-to-end):
```
python example/run_ablation.py --device cuda
```
## Notes
- Heuristic: integer-like values with low cardinality (<=10) are treated as
  discrete. All other numeric columns are continuous.

example/ablation_splits.py (new file, 114 lines)

@@ -0,0 +1,114 @@
#!/usr/bin/env python3
"""Generate multiple continuous/discrete splits for ablation."""
import argparse
import json
from pathlib import Path

from data_utils import iter_rows
from platform_utils import resolve_path, safe_path, ensure_dir


def parse_args():
    parser = argparse.ArgumentParser(description="Generate split variants for ablation.")
    base_dir = Path(__file__).resolve().parent
    repo_dir = base_dir.parent.parent
    parser.add_argument("--data-glob", default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz"))
    parser.add_argument("--max-rows", type=int, default=50000)
    parser.add_argument("--out-dir", default=str(base_dir / "results" / "ablation_splits"))
    parser.add_argument("--time-col", default="time")
    return parser.parse_args()


def analyze_columns(paths, max_rows, time_col):
    """Scan up to max_rows rows and record per column: numeric/int-like flags, unique values, non-empty count."""
    stats = {}
    rows = 0
    for path in paths:
        for row in iter_rows(path):
            rows += 1
            for c, v in row.items():
                if c == time_col:
                    continue
                st = stats.setdefault(c, {"numeric": True, "int_like": True, "unique": set(), "count": 0})
                if v is None or v == "":
                    continue
                st["count"] += 1
                if st["numeric"]:
                    try:
                        fv = float(v)
                    except Exception:
                        st["numeric"] = False
                        st["int_like"] = False
                        st["unique"].add(v)
                        continue
                    if st["int_like"] and abs(fv - round(fv)) > 1e-9:
                        st["int_like"] = False
                    if len(st["unique"]) < 200:
                        st["unique"].add(fv)
                else:
                    if len(st["unique"]) < 200:
                        st["unique"].add(v)
            if max_rows is not None and rows >= max_rows:
                return stats
    return stats


def build_split(stats, time_col, int_ratio=0.98, max_unique=20):
    """Partition columns into continuous/discrete using the integer-like + low-cardinality heuristic."""
    cont = []
    disc = []
    for c, st in stats.items():
        if c == time_col:
            continue
        if st["count"] == 0:
            continue
        if not st["numeric"]:
            disc.append(c)
            continue
        unique_count = len(st["unique"])
        # if values look integer-like and low unique => discrete
        if st["int_like"] and unique_count <= max_unique:
            disc.append(c)
        else:
            cont.append(c)
    return cont, disc


def main():
    args = parse_args()
    base_dir = Path(__file__).resolve().parent
    glob_path = resolve_path(base_dir, args.data_glob)
    paths = sorted(Path(glob_path).parent.glob(Path(glob_path).name))
    if not paths:
        raise SystemExit("no train files found under %s" % str(glob_path))
    paths = [safe_path(p) for p in paths]
    stats = analyze_columns(paths, args.max_rows, args.time_col)
    ensure_dir(args.out_dir)
    # baseline (current heuristic)
    cont, disc = build_split(stats, args.time_col, max_unique=10)
    baseline = {"time_column": args.time_col, "continuous": sorted(cont), "discrete": sorted(disc)}
    # stricter discrete
    cont_s, disc_s = build_split(stats, args.time_col, max_unique=5)
    strict = {"time_column": args.time_col, "continuous": sorted(cont_s), "discrete": sorted(disc_s)}
    # looser discrete
    cont_l, disc_l = build_split(stats, args.time_col, max_unique=30)
    loose = {"time_column": args.time_col, "continuous": sorted(cont_l), "discrete": sorted(disc_l)}
    out_dir = Path(args.out_dir)
    with open(out_dir / "split_baseline.json", "w", encoding="utf-8") as f:
        json.dump(baseline, f, indent=2)
    with open(out_dir / "split_strict.json", "w", encoding="utf-8") as f:
        json.dump(strict, f, indent=2)
    with open(out_dir / "split_loose.json", "w", encoding="utf-8") as f:
        json.dump(loose, f, indent=2)
    print("wrote", out_dir / "split_baseline.json")
    print("wrote", out_dir / "split_strict.json")
    print("wrote", out_dir / "split_loose.json")


if __name__ == "__main__":
    main()
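
For reference, each split file written above is a small JSON document with the keys assembled in `main()` (`time_column`, `continuous`, `discrete`). A quick way to inspect one, assuming the default `--out-dir`:
```
import json

# Load a generated variant and summarise the column split.
with open("example/results/ablation_splits/split_strict.json", encoding="utf-8") as f:
    split = json.load(f)
print(split["time_column"], len(split["continuous"]), len(split["discrete"]))
```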

example/prepare_data.py (modified)

@@ -1,12 +1,13 @@
#!/usr/bin/env python3
"""Prepare vocab and normalization stats for HAI 21.03."""
import argparse
import json
from pathlib import Path
from typing import Optional
from data_utils import compute_cont_stats, build_disc_stats, load_split
from platform_utils import safe_path, ensure_dir, resolve_path
BASE_DIR = Path(__file__).resolve().parent
REPO_DIR = BASE_DIR.parent.parent
@@ -16,22 +17,39 @@ OUT_STATS = BASE_DIR / "results" / "cont_stats.json"
OUT_VOCAB = BASE_DIR / "results" / "disc_vocab.json"
def parse_args():
    parser = argparse.ArgumentParser(description="Prepare vocab and stats for HAI.")
    parser.add_argument("--data-glob", default=str(DATA_GLOB), help="Glob for train CSVs")
    parser.add_argument("--split-path", default=str(SPLIT_PATH), help="Split JSON path")
    parser.add_argument("--out-stats", default=str(OUT_STATS), help="Output stats JSON")
    parser.add_argument("--out-vocab", default=str(OUT_VOCAB), help="Output vocab JSON")
    parser.add_argument("--max-rows", type=int, default=50000, help="Row cap for speed")
    return parser.parse_args()
def main(max_rows: Optional[int] = None, split_path: Optional[str] = None, data_glob: Optional[str] = None,
         out_stats: Optional[str] = None, out_vocab: Optional[str] = None):
    split_path = split_path or str(SPLIT_PATH)
    data_glob = data_glob or str(DATA_GLOB)
    out_stats = out_stats or str(OUT_STATS)
    out_vocab = out_vocab or str(OUT_VOCAB)
    split = load_split(safe_path(split_path))
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
    glob_path = resolve_path(BASE_DIR, data_glob)
    data_paths = sorted(Path(glob_path).parent.glob(Path(glob_path).name))
    if not data_paths:
        raise SystemExit("no train files found under %s" % str(glob_path))
    data_paths = [safe_path(p) for p in data_paths]
    mean, std, vmin, vmax, int_like, max_decimals = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows)
    vocab, top_token = build_disc_stats(data_paths, disc_cols, max_rows=max_rows)
    ensure_dir(Path(out_stats).parent)
    with open(safe_path(out_stats), "w", encoding="utf-8") as f:
        json.dump(
            {
                "mean": mean,
@@ -46,10 +64,16 @@ def main(max_rows: Optional[int] = None):
            indent=2,
        )
    with open(safe_path(out_vocab), "w", encoding="utf-8") as f:
        json.dump({"vocab": vocab, "top_token": top_token, "max_rows": max_rows}, f, indent=2)
if __name__ == "__main__":
    args = parse_args()
    main(
        max_rows=args.max_rows,
        split_path=args.split_path,
        data_glob=args.data_glob,
        out_stats=args.out_stats,
        out_vocab=args.out_vocab,
    )
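
Because every path is now an optional keyword argument with a CLI default, `main()` stays callable from Python as well as from the command line. A sketch, assuming `example/` is on `sys.path` and relative paths are given from the directory that contains `example/`:
```
from prepare_data import main

# Mirrors the per-variant prepare step in run_ablation.py for the "strict" split;
# an omitted data_glob falls back to the module-level DATA_GLOB default.
main(
    split_path="example/results/ablation_splits/split_strict.json",
    out_stats="example/results/cont_stats_split_strict.json",
    out_vocab="example/results/disc_vocab_split_strict.json",
    max_rows=50000,
)
```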

example/run_ablation.py (new file, 96 lines)

@@ -0,0 +1,96 @@
#!/usr/bin/env python3
"""One-click ablation runner for split variants."""
import argparse
import json
import subprocess
import sys
from pathlib import Path

from platform_utils import safe_path, is_windows


def run(cmd):
    cmd = [safe_path(c) for c in cmd]
    if is_windows():
        subprocess.run(cmd, check=True, shell=False)
    else:
        subprocess.run(cmd, check=True)


def parse_args():
    parser = argparse.ArgumentParser(description="Run ablations over split variants.")
    base_dir = Path(__file__).resolve().parent
    parser.add_argument("--device", default="auto")
    parser.add_argument("--config", default=str(base_dir / "config.json"))
    parser.add_argument("--data-glob", default=str(base_dir.parent.parent / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz"))
    parser.add_argument("--max-rows", type=int, default=50000)
    return parser.parse_args()


def main():
    args = parse_args()
    base_dir = Path(__file__).resolve().parent
    results_dir = base_dir / "results"
    splits_dir = results_dir / "ablation_splits"
    splits_dir.mkdir(parents=True, exist_ok=True)
    # generate splits (baseline/strict/loose)
    run([sys.executable, str(base_dir / "ablation_splits.py"), "--data-glob", args.data_glob, "--max-rows", str(args.max_rows)])
    split_files = [
        splits_dir / "split_baseline.json",
        splits_dir / "split_strict.json",
        splits_dir / "split_loose.json",
    ]
    for split_path in split_files:
        tag = split_path.stem
        # per-variant prepare: stats and vocab written under split-specific names
        run([
            sys.executable,
            str(base_dir / "prepare_data.py"),
            "--data-glob",
            args.data_glob,
            "--split-path",
            str(split_path),
            "--out-stats",
            str(results_dir / f"cont_stats_{tag}.json"),
            "--out-vocab",
            str(results_dir / f"disc_vocab_{tag}.json"),
        ])
        # load base config, override split/stats/vocab/out_dir
        cfg_path = Path(args.config)
        cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
        cfg["split_path"] = str(split_path)
        cfg["stats_path"] = str(results_dir / f"cont_stats_{tag}.json")
        cfg["vocab_path"] = str(results_dir / f"disc_vocab_{tag}.json")
        cfg["out_dir"] = str(results_dir / f"ablation_{tag}")
        temp_cfg = results_dir / f"config_{tag}.json"
        temp_cfg.write_text(json.dumps(cfg, indent=2), encoding="utf-8")
        run([sys.executable, str(base_dir / "train.py"), "--config", str(temp_cfg), "--device", args.device])
        run([
            sys.executable,
            str(base_dir / "export_samples.py"),
            "--include-time",
            "--device",
            args.device,
            "--config",
            str(temp_cfg),
            "--timesteps",
            str(cfg.get("timesteps", 400)),
            "--seq-len",
            str(cfg.get("sample_seq_len", cfg.get("seq_len", 128))),
            "--batch-size",
            str(cfg.get("sample_batch_size", 8)),
            "--clip-k",
            str(cfg.get("clip_k", 3.0)),
            "--use-ema",
        ])
        run([sys.executable, str(base_dir / "evaluate_generated.py"), "--out", str(results_dir / f"ablation_{tag}" / "eval.json")])


if __name__ == "__main__":
    main()
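
The runner leaves one `eval.json` per variant under `results/ablation_<tag>/`. A small sketch for collecting them side by side, assuming (without having checked `evaluate_generated.py`) that each file is a flat JSON object of metric names to values:
```
import json
from pathlib import Path

results_dir = Path("example/results")
for tag in ("split_baseline", "split_strict", "split_loose"):
    eval_path = results_dir / f"ablation_{tag}" / "eval.json"
    if eval_path.exists():
        # Assumed shape: {"metric": value, ...}; adjust to the real schema.
        print(tag, json.loads(eval_path.read_text(encoding="utf-8")))
```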