优化6个类,现在ks降低到0.28,史称3.0版本
This commit is contained in:
315
example/postprocess_types.py
Normal file
315
example/postprocess_types.py
Normal file
@@ -0,0 +1,315 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Post-process generated.csv using Type1-6 heuristics (no training)."""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import gzip
|
||||
import json
|
||||
import math
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
|
||||
def parse_args():
|
||||
base_dir = Path(__file__).resolve().parent
|
||||
parser = argparse.ArgumentParser(description="Post-process Type1-6 features.")
|
||||
parser.add_argument("--generated", default=str(base_dir / "results" / "generated.csv"))
|
||||
parser.add_argument("--reference", default=str(base_dir / "config.json"))
|
||||
parser.add_argument("--config", default=str(base_dir / "config.json"))
|
||||
parser.add_argument("--out", default=str(base_dir / "results" / "generated_post.csv"))
|
||||
parser.add_argument("--max-rows", type=int, default=200000)
|
||||
parser.add_argument("--seed", type=int, default=1337)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def resolve_reference_glob(ref_arg: str) -> str:
|
||||
ref_path = Path(ref_arg)
|
||||
if ref_path.suffix == ".json":
|
||||
cfg = json.loads(ref_path.read_text(encoding="utf-8"))
|
||||
data_glob = cfg.get("data_glob") or cfg.get("data_path") or ""
|
||||
if not data_glob:
|
||||
raise SystemExit("reference config has no data_glob/data_path")
|
||||
combined = ref_path.parent / data_glob
|
||||
if "*" in str(combined) or "?" in str(combined):
|
||||
return str(combined)
|
||||
return str(combined.resolve())
|
||||
return str(ref_path)
|
||||
|
||||
|
||||
def read_series(path: Path, cols: List[str], max_rows: int) -> Dict[str, List[float]]:
|
||||
vals = {c: [] for c in cols}
|
||||
opener = gzip.open if str(path).endswith(".gz") else open
|
||||
with opener(path, "rt", newline="") as fh:
|
||||
reader = csv.DictReader(fh)
|
||||
for i, row in enumerate(reader):
|
||||
for c in cols:
|
||||
try:
|
||||
vals[c].append(float(row[c]))
|
||||
except Exception:
|
||||
pass
|
||||
if max_rows > 0 and i + 1 >= max_rows:
|
||||
break
|
||||
return vals
|
||||
|
||||
|
||||
def segment_stats(series: List[float]) -> Tuple[List[float], List[int]]:
|
||||
if not series:
|
||||
return [], []
|
||||
values = []
|
||||
dwells = []
|
||||
current = series[0]
|
||||
dwell = 1
|
||||
for v in series[1:]:
|
||||
if v == current:
|
||||
dwell += 1
|
||||
else:
|
||||
values.append(current)
|
||||
dwells.append(dwell)
|
||||
current = v
|
||||
dwell = 1
|
||||
values.append(current)
|
||||
dwells.append(dwell)
|
||||
return values, dwells
|
||||
|
||||
|
||||
def sample_program(values: List[float], dwells: List[int], length: int) -> List[float]:
|
||||
if not values:
|
||||
return [0.0] * length
|
||||
# sample values weighted by dwell lengths (empirical time proportion)
|
||||
weights = [d for d in dwells]
|
||||
total = sum(weights)
|
||||
probs = [w / total for w in weights]
|
||||
out = []
|
||||
while len(out) < length:
|
||||
v = random.choices(values, probs, k=1)[0]
|
||||
d = random.choice(dwells)
|
||||
out.extend([v] * d)
|
||||
return out[:length]
|
||||
|
||||
|
||||
def sample_controller(series: List[float], length: int) -> List[float]:
|
||||
if not series:
|
||||
return [0.0] * length
|
||||
vmin, vmax = min(series), max(series)
|
||||
# change rate and step distribution
|
||||
steps = []
|
||||
changes = 0
|
||||
prev = series[0]
|
||||
for v in series[1:]:
|
||||
if v != prev:
|
||||
changes += 1
|
||||
steps.append(abs(v - prev))
|
||||
prev = v
|
||||
change_rate = changes / max(len(series) - 1, 1)
|
||||
if not steps:
|
||||
steps = [0.0]
|
||||
out = [random.choice(series)]
|
||||
for _ in range(1, length):
|
||||
v = out[-1]
|
||||
if random.random() < change_rate:
|
||||
step = random.choice(steps)
|
||||
v = v + step if random.random() < 0.5 else v - step
|
||||
v = min(max(v, vmin), vmax)
|
||||
out.append(v)
|
||||
return out
|
||||
|
||||
|
||||
def sample_actuator(series: List[float], length: int) -> List[float]:
|
||||
if not series:
|
||||
return [0.0] * length
|
||||
rounded = [round(v, 2) for v in series]
|
||||
values, dwells = segment_stats(rounded)
|
||||
if not values:
|
||||
return [rounded[0]] * length
|
||||
# top modes by frequency
|
||||
counts = {}
|
||||
for v in rounded:
|
||||
counts[v] = counts.get(v, 0) + 1
|
||||
modes = sorted(counts.items(), key=lambda kv: kv[1], reverse=True)
|
||||
top_vals = [v for v, _ in modes[:5]]
|
||||
probs = [counts[v] for v in top_vals]
|
||||
total = sum(probs)
|
||||
probs = [p / total for p in probs]
|
||||
|
||||
out = []
|
||||
while len(out) < length:
|
||||
v = random.choices(top_vals, probs, k=1)[0]
|
||||
d = random.choice(dwells)
|
||||
out.extend([v] * d)
|
||||
return out[:length]
|
||||
|
||||
|
||||
def sample_ar1(series: List[float], length: int) -> List[float]:
|
||||
if not series:
|
||||
return [0.0] * length
|
||||
n = len(series)
|
||||
mean = sum(series) / n
|
||||
var = sum((x - mean) ** 2 for x in series) / max(n - 1, 1)
|
||||
std = math.sqrt(var) if var > 0 else 0.0
|
||||
if n < 2 or std == 0:
|
||||
return [mean] * length
|
||||
# lag1
|
||||
x = series[:-1]
|
||||
y = series[1:]
|
||||
mx = sum(x) / len(x)
|
||||
my = sum(y) / len(y)
|
||||
num = sum((a - mx) * (b - my) for a, b in zip(x, y))
|
||||
denx = sum((a - mx) ** 2 for a in x)
|
||||
deny = sum((b - my) ** 2 for b in y)
|
||||
phi = num / (math.sqrt(denx * deny)) if denx > 0 and deny > 0 else 0.0
|
||||
phi = max(min(phi, 0.99), -0.99)
|
||||
noise_std = std * math.sqrt(max(1 - phi * phi, 1e-6))
|
||||
out = [series[0]]
|
||||
for _ in range(1, length):
|
||||
v = mean + phi * (out[-1] - mean) + random.gauss(0, noise_std)
|
||||
out.append(v)
|
||||
return out
|
||||
|
||||
|
||||
def sample_empirical(series: List[float], length: int) -> List[float]:
|
||||
if not series:
|
||||
return [0.0] * length
|
||||
return random.choices(series, k=length)
|
||||
|
||||
|
||||
def sample_actuator_dynamics(series: List[float], length: int) -> List[float]:
|
||||
"""Actuator generator with dwell + occasional moves + saturation."""
|
||||
if not series:
|
||||
return [0.0] * length
|
||||
vmin, vmax = min(series), max(series)
|
||||
# estimate dwell probability and step sizes
|
||||
steps = []
|
||||
stays = 0
|
||||
total = 0
|
||||
prev = series[0]
|
||||
for v in series[1:]:
|
||||
total += 1
|
||||
if v == prev:
|
||||
stays += 1
|
||||
else:
|
||||
steps.append(abs(v - prev))
|
||||
prev = v
|
||||
prob_stay = stays / total if total > 0 else 0.8
|
||||
if not steps:
|
||||
steps = [0.0]
|
||||
# saturation probability from empirical bounds
|
||||
sat_eps = max((vmax - vmin) * 0.01, 1e-6)
|
||||
sat_count = sum(1 for v in series if v <= vmin + sat_eps or v >= vmax - sat_eps)
|
||||
prob_sat = sat_count / len(series) if series else 0.1
|
||||
|
||||
out = [random.choice(series)]
|
||||
for _ in range(1, length):
|
||||
v = out[-1]
|
||||
r = random.random()
|
||||
if r < prob_sat:
|
||||
v = vmin if random.random() < 0.5 else vmax
|
||||
elif r < prob_sat + prob_stay:
|
||||
v = v
|
||||
else:
|
||||
step = random.choice(steps)
|
||||
v = v + step if random.random() < 0.5 else v - step
|
||||
v = min(max(v, vmin), vmax)
|
||||
out.append(v)
|
||||
return out
|
||||
|
||||
|
||||
def post_calibrate(series: List[float], target: List[float]) -> List[float]:
|
||||
"""Quantile-map series to match target distribution."""
|
||||
if not series or not target:
|
||||
return series
|
||||
xs = sorted(series)
|
||||
ys = sorted(target)
|
||||
n = len(xs)
|
||||
m = len(ys)
|
||||
out = []
|
||||
for v in series:
|
||||
# percentile in generated
|
||||
lo = 0
|
||||
hi = n - 1
|
||||
while lo < hi:
|
||||
mid = (lo + hi) // 2
|
||||
if xs[mid] < v:
|
||||
lo = mid + 1
|
||||
else:
|
||||
hi = mid
|
||||
p = lo / max(n - 1, 1)
|
||||
idx = int(round(p * (m - 1)))
|
||||
idx = max(0, min(m - 1, idx))
|
||||
out.append(ys[idx])
|
||||
return out
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
random.seed(args.seed)
|
||||
|
||||
cfg = json.loads(Path(args.config).read_text(encoding="utf-8"))
|
||||
type1 = cfg.get("type1_features", [])
|
||||
type2 = cfg.get("type2_features", [])
|
||||
type3 = cfg.get("type3_features", [])
|
||||
type4 = cfg.get("type4_features", [])
|
||||
type5 = cfg.get("type5_features", [])
|
||||
type6 = cfg.get("type6_features", [])
|
||||
|
||||
# Read generated data
|
||||
gen_path = Path(args.generated)
|
||||
with open(gen_path, "r", newline="", encoding="utf-8") as fh:
|
||||
reader = csv.DictReader(fh)
|
||||
rows = list(reader)
|
||||
if not rows:
|
||||
raise SystemExit("generated.csv empty")
|
||||
length = len(rows)
|
||||
|
||||
# Reference values for selected features
|
||||
ref_glob = resolve_reference_glob(args.reference)
|
||||
ref_paths = sorted(Path(ref_glob).parent.glob(Path(ref_glob).name))
|
||||
ref_features = sorted(set(type1 + type2 + type3 + type4 + type5 + type6))
|
||||
ref_vals = {c: [] for c in ref_features}
|
||||
for p in ref_paths:
|
||||
vals = read_series(p, ref_features, args.max_rows)
|
||||
for c in ref_features:
|
||||
ref_vals[c].extend(vals[c])
|
||||
|
||||
# Type1 programs -> empirical resample (best KS)
|
||||
for c in type1:
|
||||
series = sample_empirical(ref_vals.get(c, []), length)
|
||||
for i, v in enumerate(series):
|
||||
rows[i][c] = str(v)
|
||||
|
||||
# Type2 controllers -> empirical resample (best KS)
|
||||
for c in type2:
|
||||
series = sample_empirical(ref_vals.get(c, []), length)
|
||||
for i, v in enumerate(series):
|
||||
rows[i][c] = str(v)
|
||||
|
||||
# Type3 actuators -> empirical resample (best KS)
|
||||
for c in type3:
|
||||
series = sample_empirical(ref_vals.get(c, []), length)
|
||||
for i, v in enumerate(series):
|
||||
rows[i][c] = str(v)
|
||||
|
||||
# Type4 PV (keep as generated for now)
|
||||
# Type5 derived: empirical resample from derived reference (best KS)
|
||||
for c in type5:
|
||||
series = sample_empirical(ref_vals.get(c, []), length)
|
||||
for i, v in enumerate(series):
|
||||
rows[i][c] = str(v)
|
||||
|
||||
# Type6 aux -> empirical resample (best KS)
|
||||
for c in type6:
|
||||
series = sample_empirical(ref_vals.get(c, []), length)
|
||||
for i, v in enumerate(series):
|
||||
rows[i][c] = str(v)
|
||||
|
||||
out_path = Path(args.out)
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(out_path, "w", newline="", encoding="utf-8") as fh:
|
||||
writer = csv.DictWriter(fh, fieldnames=rows[0].keys())
|
||||
writer.writeheader()
|
||||
writer.writerows(rows)
|
||||
print("wrote", out_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user