@@ -72,3 +72,4 @@ python example/run_pipeline.py --device auto
 - The script only samples the first 5000 rows to stay fast.
 - `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it.
 - `train.py` and `sample.py` auto-select GPU if available; otherwise they fall back to CPU.
+- Optional two-stage temporal model (`use_temporal_stage1`) trains a GRU trend backbone first, then diffusion models residuals.

@@ -38,6 +38,12 @@
   "cont_target": "x0",
   "cont_clamp_x0": 5.0,
   "shuffle_buffer": 256,
+  "use_temporal_stage1": true,
+  "temporal_hidden_dim": 256,
+  "temporal_num_layers": 1,
+  "temporal_dropout": 0.0,
+  "temporal_epochs": 2,
+  "temporal_lr": 0.001,
   "sample_batch_size": 8,
   "sample_seq_len": 128
 }

@@ -13,7 +13,7 @@ import torch
 import torch.nn.functional as F
 
 from data_utils import load_split
-from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
+from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, cosine_beta_schedule
 from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path
 
 

@@ -140,6 +140,10 @@ def main():
         raise SystemExit("use_condition enabled but no files matched data_glob: %s" % cfg_glob)
     cont_target = str(cfg.get("cont_target", "eps"))
     cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
+    use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
+    temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
+    temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
+    temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
 
     model = HybridDiffusionModel(
         cont_dim=len(cont_cols),

@@ -163,6 +167,20 @@ def main():
     model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True))
     model.eval()
 
+    temporal_model = None
+    if use_temporal_stage1:
+        temporal_model = TemporalGRUGenerator(
+            input_dim=len(cont_cols),
+            hidden_dim=temporal_hidden_dim,
+            num_layers=temporal_num_layers,
+            dropout=temporal_dropout,
+        ).to(device)
+        temporal_path = Path(args.model_path).with_name("temporal.pt")
+        if not temporal_path.exists():
+            raise SystemExit(f"missing temporal model file: {temporal_path}")
+        temporal_model.load_state_dict(torch.load(temporal_path, map_location=device, weights_only=True))
+        temporal_model.eval()
+
     betas = cosine_beta_schedule(args.timesteps).to(device)
     alphas = 1.0 - betas
     alphas_cumprod = torch.cumprod(alphas, dim=0)

@@ -189,6 +207,10 @@ def main():
         cond_id = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long)
         cond = cond_id
 
+    trend = None
+    if temporal_model is not None:
+        trend = temporal_model.generate(args.batch_size, args.seq_len, device)
+
     for t in reversed(range(args.timesteps)):
         t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
         eps_pred, logits = model(x_cont, x_disc, t_batch, cond)

@@ -225,6 +247,8 @@ def main():
             )
             x_disc[:, :, i][mask] = sampled[mask]
 
+    if trend is not None:
+        x_cont = x_cont + trend
     # move to CPU for export
     x_cont = x_cont.cpu()
     x_disc = x_disc.cpu()

@@ -66,6 +66,46 @@ class SinusoidalTimeEmbedding(nn.Module):
         return emb
 
 
+class TemporalGRUGenerator(nn.Module):
+    """Stage-1 temporal generator for the sequence trend."""
+
+    def __init__(self, input_dim: int, hidden_dim: int = 256, num_layers: int = 1, dropout: float = 0.0):
+        super().__init__()
+        self.start_token = nn.Parameter(torch.zeros(input_dim))
+        self.gru = nn.GRU(
+            input_dim,
+            hidden_dim,
+            num_layers=num_layers,
+            dropout=dropout if num_layers > 1 else 0.0,
+            batch_first=True,
+        )
+        self.out = nn.Linear(hidden_dim, input_dim)
+
+    def forward_teacher(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Teacher-forced next-step prediction. Returns the trend sequence and the raw predictions."""
+        if x.size(1) < 2:
+            raise ValueError("sequence length must be >= 2 for teacher forcing")
+        inp = x[:, :-1, :]
+        out, _ = self.gru(inp)
+        pred_next = self.out(out)
+        trend = torch.zeros_like(x)
+        trend[:, 0, :] = x[:, 0, :]
+        trend[:, 1:, :] = pred_next
+        return trend, pred_next
+
+    def generate(self, batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor:
+        """Autoregressively generate a backbone sequence."""
+        h = None
+        prev = self.start_token.unsqueeze(0).expand(batch_size, -1).to(device)
+        outputs = []
+        for _ in range(seq_len):
+            out, h = self.gru(prev.unsqueeze(1), h)
+            nxt = self.out(out.squeeze(1))
+            outputs.append(nxt.unsqueeze(1))
+            prev = nxt
+        return torch.cat(outputs, dim=1)
+
+
 class HybridDiffusionModel(nn.Module):
     def __init__(
         self,

@@ -75,6 +75,8 @@ def main():
         run([sys.executable, str(base_dir / "evaluate_generated.py"), "--reference", str(ref)])
     else:
         run([sys.executable, str(base_dir / "evaluate_generated.py")])
+
+    run([sys.executable, str(base_dir / "summary_metrics.py")])
 
 
 if __name__ == "__main__":

@@ -10,7 +10,7 @@ import torch
 import torch.nn.functional as F
 
 from data_utils import load_split
-from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
+from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, cosine_beta_schedule
 from platform_utils import resolve_device, safe_path, ensure_dir
 
 BASE_DIR = Path(__file__).resolve().parent

@@ -47,6 +47,10 @@ def main():
     cond_dim = int(cfg.get("cond_dim", 32))
     use_tanh_eps = bool(cfg.get("use_tanh_eps", False))
     eps_scale = float(cfg.get("eps_scale", 1.0))
+    use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
+    temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
+    temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
+    temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
     cont_target = str(cfg.get("cont_target", "eps"))
     cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
     model_time_dim = int(cfg.get("model_time_dim", 64))

@@ -92,6 +96,20 @@ def main():
     model.load_state_dict(torch.load(str(MODEL_PATH), map_location=DEVICE, weights_only=True))
     model.eval()
 
+    temporal_model = None
+    if use_temporal_stage1:
+        temporal_model = TemporalGRUGenerator(
+            input_dim=len(cont_cols),
+            hidden_dim=temporal_hidden_dim,
+            num_layers=temporal_num_layers,
+            dropout=temporal_dropout,
+        ).to(DEVICE)
+        temporal_path = BASE_DIR / "results" / "temporal.pt"
+        if not temporal_path.exists():
+            raise SystemExit(f"missing temporal model file: {temporal_path}")
+        temporal_model.load_state_dict(torch.load(str(temporal_path), map_location=DEVICE, weights_only=True))
+        temporal_model.eval()
+
     betas = cosine_beta_schedule(timesteps).to(DEVICE)
     alphas = 1.0 - betas
     alphas_cumprod = torch.cumprod(alphas, dim=0)

@@ -110,6 +128,10 @@ def main():
             raise SystemExit("use_condition enabled but no files matched data_glob")
         cond = torch.randint(0, cond_vocab_size, (batch_size,), device=DEVICE, dtype=torch.long)
 
+    trend = None
+    if temporal_model is not None:
+        trend = temporal_model.generate(batch_size, seq_len, DEVICE)
+
     for t in reversed(range(timesteps)):
         t_batch = torch.full((batch_size,), t, device=DEVICE, dtype=torch.long)
         eps_pred, logits = model(x_cont, x_disc, t_batch, cond)

@@ -146,6 +168,8 @@ def main():
             sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(BATCH_SIZE, SEQ_LEN)
             x_disc[:, :, i][mask] = sampled[mask]
 
+    if trend is not None:
+        x_cont = x_cont + trend
     print("sampled_cont_shape", tuple(x_cont.shape))
     print("sampled_disc_shape", tuple(x_disc.shape))
 

example/summary_metrics.py (new file)
@@ -0,0 +1,63 @@
#!/usr/bin/env python3
"""Print average metrics from eval.json and compare with previous run."""

import json
from datetime import datetime
from pathlib import Path


def mean(values):
    return sum(values) / len(values) if values else None


def parse_last_row(history_path: Path):
    if not history_path.exists():
        return None
    rows = history_path.read_text(encoding="utf-8").strip().splitlines()
    if len(rows) < 2:
        return None
    last = rows[-1].split(",")
    if len(last) < 4:
        return None
    return {
        "avg_ks": float(last[1]),
        "avg_jsd": float(last[2]),
        "avg_lag1_diff": float(last[3]),
    }


def main():
    base_dir = Path(__file__).resolve().parent
    eval_path = base_dir / "results" / "eval.json"
    if not eval_path.exists():
        raise SystemExit(f"missing eval.json: {eval_path}")

    obj = json.loads(eval_path.read_text(encoding="utf-8"))
    ks = list(obj.get("continuous_ks", {}).values())
    jsd = list(obj.get("discrete_jsd", {}).values())
    lag = list(obj.get("continuous_lag1_diff", {}).values())

    avg_ks = mean(ks)
    avg_jsd = mean(jsd)
    avg_lag1 = mean(lag)

    history_path = base_dir / "results" / "metrics_history.csv"
    prev = parse_last_row(history_path)

    if not history_path.exists():
        history_path.write_text("timestamp,avg_ks,avg_jsd,avg_lag1_diff\n", encoding="utf-8")
    with history_path.open("a", encoding="utf-8") as f:
        f.write(f"{datetime.utcnow().isoformat()},{avg_ks},{avg_jsd},{avg_lag1}\n")

    print("avg_ks", avg_ks)
    print("avg_jsd", avg_jsd)
    print("avg_lag1_diff", avg_lag1)

    if prev is not None:
        print("delta_avg_ks", avg_ks - prev["avg_ks"])
        print("delta_avg_jsd", avg_jsd - prev["avg_jsd"])
        print("delta_avg_lag1_diff", avg_lag1 - prev["avg_lag1_diff"])


if __name__ == "__main__":
    main()

@@ -14,6 +14,7 @@ import torch.nn.functional as F
 from data_utils import load_split, windowed_batches
 from hybrid_diffusion import (
     HybridDiffusionModel,
+    TemporalGRUGenerator,
     cosine_beta_schedule,
     q_sample_continuous,
     q_sample_discrete,

@@ -64,6 +65,12 @@ DEFAULTS = {
     "cont_loss_eps": 1e-6,
     "cont_target": "eps",  # eps | x0
     "cont_clamp_x0": 0.0,
+    "use_temporal_stage1": True,
+    "temporal_hidden_dim": 256,
+    "temporal_num_layers": 1,
+    "temporal_dropout": 0.0,
+    "temporal_epochs": 2,
+    "temporal_lr": 1e-3,
 }
 
 

@@ -194,6 +201,19 @@ def main():
         eps_scale=float(config.get("eps_scale", 1.0)),
     ).to(device)
     opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))
+    temporal_model = None
+    opt_temporal = None
+    if bool(config.get("use_temporal_stage1", False)):
+        temporal_model = TemporalGRUGenerator(
+            input_dim=len(cont_cols),
+            hidden_dim=int(config.get("temporal_hidden_dim", 256)),
+            num_layers=int(config.get("temporal_num_layers", 1)),
+            dropout=float(config.get("temporal_dropout", 0.0)),
+        ).to(device)
+        opt_temporal = torch.optim.Adam(
+            temporal_model.parameters(),
+            lr=float(config.get("temporal_lr", config["lr"])),
+        )
     ema = EMA(model, float(config["ema_decay"])) if config.get("use_ema") else None
 
     betas = cosine_beta_schedule(int(config["timesteps"])).to(device)

@@ -208,6 +228,37 @@ def main():
     with open(os.path.join(out_dir, "config_used.json"), "w", encoding="utf-8") as f:
         json.dump(config, f, indent=2)
 
+    if temporal_model is not None and opt_temporal is not None:
+        for epoch in range(int(config.get("temporal_epochs", 1))):
+            for step, batch in enumerate(
+                windowed_batches(
+                    data_paths,
+                    cont_cols,
+                    disc_cols,
+                    vocab,
+                    mean,
+                    std,
+                    batch_size=int(config["batch_size"]),
+                    seq_len=int(config["seq_len"]),
+                    max_batches=int(config["max_batches"]),
+                    return_file_id=False,
+                    transforms=transforms,
+                    shuffle_buffer=int(config.get("shuffle_buffer", 0)),
+                )
+            ):
+                x_cont, _ = batch
+                x_cont = x_cont.to(device)
+                trend, pred_next = temporal_model.forward_teacher(x_cont)
+                temporal_loss = F.mse_loss(pred_next, x_cont[:, 1:, :])
+                opt_temporal.zero_grad()
+                temporal_loss.backward()
+                if float(config.get("grad_clip", 0.0)) > 0:
+                    torch.nn.utils.clip_grad_norm_(temporal_model.parameters(), float(config["grad_clip"]))
+                opt_temporal.step()
+                if step % int(config["log_every"]) == 0:
+                    print("temporal_epoch", epoch, "step", step, "loss", float(temporal_loss))
+        torch.save(temporal_model.state_dict(), os.path.join(out_dir, "temporal.pt"))
+
     total_step = 0
     for epoch in range(int(config["epochs"])):
         for step, batch in enumerate(

@@ -235,10 +286,17 @@ def main():
             x_cont = x_cont.to(device)
             x_disc = x_disc.to(device)
 
+            trend = None
+            if temporal_model is not None:
+                temporal_model.eval()
+                with torch.no_grad():
+                    trend, _ = temporal_model.forward_teacher(x_cont)
+            x_cont_resid = x_cont if trend is None else x_cont - trend
+
             bsz = x_cont.size(0)
             t = torch.randint(0, int(config["timesteps"]), (bsz,), device=device)
 
-            x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
+            x_cont_t, noise = q_sample_continuous(x_cont_resid, t, alphas_cumprod)
 
             mask_tokens = torch.tensor(vocab_sizes, device=device)
             x_disc_t, mask = q_sample_discrete(

@@ -253,7 +311,7 @@ def main():
 
             cont_target = str(config.get("cont_target", "eps"))
             if cont_target == "x0":
-                x0_target = x_cont
+                x0_target = x_cont_resid
                 if float(config.get("cont_clamp_x0", 0.0)) > 0:
                     x0_target = torch.clamp(x0_target, -float(config["cont_clamp_x0"]), float(config["cont_clamp_x0"]))
                 loss_base = (eps_pred - x0_target) ** 2

@@ -308,11 +366,15 @@ def main():
                 }
                 if ema is not None:
                     ckpt["ema"] = ema.state_dict()
+                if temporal_model is not None:
+                    ckpt["temporal"] = temporal_model.state_dict()
                 torch.save(ckpt, os.path.join(out_dir, "model_ckpt.pt"))
 
     torch.save(model.state_dict(), os.path.join(out_dir, "model.pt"))
     if ema is not None:
         torch.save(ema.state_dict(), os.path.join(out_dir, "model_ema.pt"))
+    if temporal_model is not None:
+        torch.save(temporal_model.state_dict(), os.path.join(out_dir, "temporal.pt"))
 
 
 if __name__ == "__main__":

report.md (new file)

# Hybrid Diffusion for ICS Traffic (HAI 21.03) — Project Report

## 1. Project Goal

Build a **hybrid diffusion-based generator** for ICS traffic features, focusing on **mixed continuous + discrete** feature sequences. The output is **feature-level sequences**, not raw packets. The generator should preserve:

- **Distributional fidelity** (continuous value ranges + discrete token frequencies)
- **Temporal consistency** (time correlation and sequence structure)
- **Field/logic consistency** for discrete protocol-like columns

---

## 2. Data and Scope

**Dataset used in the current implementation:** HAI 21.03 (CSV feature traces).

**Data path (default in config):**
- `dataset/hai/hai-21.03/train*.csv.gz`

**Feature split (fixed schema):** `example/feature_split.json`
- Continuous features: sensor/process values
- Discrete features: binary/low-cardinality status/flag fields
- `time` is excluded from modeling

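A plausible shape for that file is sketched below to make the split concrete; the actual key and column names in `example/feature_split.json` are not shown in this commit and may differ.

```python
# Hypothetical layout of example/feature_split.json (key and column names are assumptions).
import json

feature_split = {
    "continuous": ["sensor_a", "sensor_b"],      # sensor / process values
    "discrete": ["pump_status", "valve_flag"],   # binary or low-cardinality status fields
    "exclude": ["time"],                         # time is excluded from modeling
}
print(json.dumps(feature_split, indent=2))
```
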
---

## 3. End-to-End Pipeline

One-command pipeline:

```
python example/run_all.py --device cuda
```

Pipeline stages:

1) **Prepare data** (`example/prepare_data.py`)
2) **Train temporal backbone** (`example/train.py`, stage 1)
3) **Train diffusion on residuals** (`example/train.py`, stage 2)
4) **Generate samples** (`example/export_samples.py`)
5) **Evaluate** (`example/evaluate_generated.py`)

The single command maps onto: data preparation → temporal backbone training → residual diffusion training → sample export → evaluation.

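For orientation, a minimal stage runner in the spirit of `example/run_all.py` is sketched below; the `run` helper, the stage list, and the absence of flag handling are simplifications, not the repository's exact code.

```python
# Minimal sketch of a stage runner (assumed structure; example/run_all.py passes
# config-driven arguments such as --reference that are omitted here).
import subprocess
import sys
from pathlib import Path

BASE_DIR = Path(__file__).resolve().parent


def run(cmd: list[str]) -> None:
    # Echo each command, then fail fast if a stage exits non-zero.
    print("+", " ".join(cmd))
    subprocess.run(cmd, check=True)


def main() -> None:
    stages = [
        "prepare_data.py",         # 1) schema + normalization stats
        "train.py",                # 2+3) temporal backbone, then residual diffusion
        "export_samples.py",       # 4) sampling + CSV export
        "evaluate_generated.py",   # 5) metrics against a reference split
        "summary_metrics.py",      # averages + history deltas
    ]
    for script in stages:
        run([sys.executable, str(BASE_DIR / script)])


if __name__ == "__main__":
    main()
```
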
---

## 4. Technical Architecture

### 4.1 Hybrid Diffusion Model (Core)

Defined in `example/hybrid_diffusion.py`.

**Inputs:**
- Continuous projection
- Discrete embeddings
- Time embedding (sinusoidal)
- Positional embedding (sequence index)
- Optional condition embedding (`file_id`)

**Backbone:**
- GRU (sequence modeling)
- Post-LayerNorm + residual MLP

**Outputs:**
- Continuous head: predicts the chosen target (`eps` or `x0`)
- Discrete heads: logits per discrete column

**Branches:**
- **Continuous branch:** Gaussian diffusion
- **Discrete branch:** mask diffusion

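As a rough illustration of how these inputs can be fused before the backbone, here is a minimal sketch; the additive fusion, the dimensions, and the learned stand-in for the sinusoidal time embedding are assumptions, not the exact layer layout of `HybridDiffusionModel`.

```python
import torch
import torch.nn as nn

# Sketch: fuse continuous projection, discrete embeddings, and time/positional/condition
# embeddings into one per-step feature, then run the GRU backbone. (Assumed fusion.)
B, T, C_CONT, N_DISC, VOCAB, D = 4, 16, 8, 3, 10, 64

cont_proj = nn.Linear(C_CONT, D)
disc_emb = nn.ModuleList([nn.Embedding(VOCAB + 1, D) for _ in range(N_DISC)])  # +1 for [MASK]
time_emb = nn.Embedding(1000, D)   # stand-in for the sinusoidal time embedding
pos_emb = nn.Embedding(T, D)       # sequence-index embedding
cond_emb = nn.Embedding(5, D)      # optional file_id condition
gru = nn.GRU(D, D, batch_first=True)

x_cont = torch.randn(B, T, C_CONT)
x_disc = torch.randint(0, VOCAB, (B, T, N_DISC))
t = torch.randint(0, 1000, (B,))
cond = torch.randint(0, 5, (B,))

h = cont_proj(x_cont)
for i, emb in enumerate(disc_emb):
    h = h + emb(x_disc[:, :, i])
h = h + time_emb(t)[:, None, :] + pos_emb(torch.arange(T))[None, :, :] + cond_emb(cond)[:, None, :]
out, _ = gru(h)   # continuous head (eps/x0) and per-column logits would follow
print(out.shape)
```
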
---

### 4.2 Stage-1 Temporal Model (GRU)

A separate GRU models the **trend backbone** of the continuous features. It is trained first, using teacher forcing to predict the next step.

Trend definition:

```
trend = GRU(x)
residual = x - trend
```

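In terms of the new `TemporalGRUGenerator`, the decomposition looks roughly like this; a condensed sketch of the flow in `train.py` and `sample.py`, with illustrative shapes and a zero placeholder where the denoised residual would come from the diffusion model.

```python
import torch
from hybrid_diffusion import TemporalGRUGenerator  # added in this commit

temporal = TemporalGRUGenerator(input_dim=8, hidden_dim=256)

# Training: teacher-forced trend; the diffusion branch only ever sees the residual.
x = torch.randn(4, 128, 8)                       # (batch, seq_len, cont_dim)
trend, pred_next = temporal.forward_teacher(x)
residual = x - trend                             # q_sample_continuous corrupts this

# Sampling: generate a fresh trend autoregressively, then add the residual back.
gen_trend = temporal.generate(4, 128, torch.device("cpu"))
x_sampled = gen_trend + torch.zeros_like(gen_trend)   # placeholder for diffusion's residual
```
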
---

## 5. Diffusion Formulations

### 5.1 Continuous Diffusion

Forward process on residuals:

```
r_t = sqrt(a_bar_t) * r + sqrt(1 - a_bar_t) * eps
```

Targets supported:

- **eps prediction**
- **x0 prediction** (default)

Current config:

```
"cont_target": "x0"
```

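A minimal sketch of this forward corruption, written as a generic DDPM-style `q_sample`; the repository's `q_sample_continuous` may differ in details such as broadcasting or schedule handling.

```python
import torch

def q_sample_continuous_sketch(r0, t, alphas_cumprod):
    # r_t = sqrt(a_bar_t) * r0 + sqrt(1 - a_bar_t) * eps, applied per sample in the batch.
    a_bar = alphas_cumprod[t].view(-1, 1, 1)   # (B, 1, 1) against (B, T, C) residuals
    eps = torch.randn_like(r0)
    r_t = a_bar.sqrt() * r0 + (1.0 - a_bar).sqrt() * eps
    return r_t, eps

betas = torch.linspace(1e-4, 0.02, 1000)       # any monotone schedule works for the sketch
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
r0 = torch.randn(4, 128, 8)
t = torch.randint(0, 1000, (4,))
r_t, eps = q_sample_continuous_sketch(r0, t, alphas_cumprod)
```
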
### 5.2 Discrete Diffusion

Mask diffusion with a cosine masking schedule:

```
p(t) = 0.5 * (1 - cos(pi * t / T))
```

Cross-entropy is computed only on the masked positions.

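A sketch of the masking corruption under this schedule; using the vocabulary size as the mask id appears consistent with the `mask_tokens = torch.tensor(vocab_sizes)` line in `train.py`, but the real `q_sample_discrete` signature is assumed here, not verified.

```python
import math
import torch

def q_sample_discrete_sketch(x_disc, t, T, mask_tokens):
    # Replace each position with its column's [MASK] token with probability
    # p(t) = 0.5 * (1 - cos(pi * t / T)); only masked positions carry the CE loss.
    p = 0.5 * (1.0 - torch.cos(math.pi * t.float() / T)).view(-1, 1, 1)   # (B, 1, 1)
    mask = torch.rand_like(x_disc, dtype=torch.float) < p
    x_t = torch.where(mask, mask_tokens.view(1, 1, -1).expand_as(x_disc), x_disc)
    return x_t, mask

x_disc = torch.randint(0, 10, (4, 128, 3))       # three discrete columns, vocab size 10
mask_tokens = torch.tensor([10, 10, 10])         # vocab size doubles as the mask id
t = torch.randint(0, 1000, (4,))
x_t, mask = q_sample_discrete_sketch(x_disc, t, T=1000, mask_tokens=mask_tokens)
```
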
---

## 6. Loss Design

Total loss:

```
L = λ * L_cont + (1 - λ) * L_disc
```

### 6.1 Continuous Loss

- `eps` target: MSE(eps_pred, eps)
- `x0` target: MSE(x0_pred, x0)
- Optional inverse-variance weighting: `cont_loss_weighting = "inv_std"`

### 6.2 Discrete Loss

Cross-entropy on masked positions only.

### 6.3 Temporal Loss

The stage-1 GRU predicts the next step:

```
L_temporal = MSE(pred_next, x[:, 1:])
```

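Putting the pieces together, a hedged sketch of how the diffusion losses combine; `lambda_cont` and the per-column averaging are assumptions standing in for the repository's actual config keys and reduction.

```python
import torch
import torch.nn.functional as F

def hybrid_loss_sketch(eps_pred, eps_true, logits_list, x_disc, mask, lambda_cont=0.5):
    # Continuous branch: MSE on the chosen target (eps here; an x0 target works the same way).
    l_cont = F.mse_loss(eps_pred, eps_true)
    # Discrete branch: cross-entropy only where the forward process masked tokens.
    l_disc = 0.0
    for i, logits in enumerate(logits_list):        # one head per discrete column
        m = mask[:, :, i]
        if m.any():
            l_disc = l_disc + F.cross_entropy(logits[m], x_disc[:, :, i][m])
    l_disc = l_disc / max(len(logits_list), 1)
    return lambda_cont * l_cont + (1.0 - lambda_cont) * l_disc

# The stage-1 temporal loss is a separate objective, optimized before diffusion training:
# l_temporal = F.mse_loss(pred_next, x[:, 1:, :])
```
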
---

## 7. Data Processing

Defined in `example/data_utils.py` and `example/prepare_data.py`.

Key steps:

- Streaming mean/std/min/max statistics plus int-like column detection
- Optional **log1p transform** for heavy-tailed continuous columns
- Discrete vocabulary plus most-frequent-token tracking
- Windowed batching with a **shuffle buffer** (sketched below)

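To make the shuffle-buffer idea concrete, here is a minimal generator sketch; the function and parameter names are illustrative, not the actual `windowed_batches` API.

```python
import random
import numpy as np

def windowed_batches_sketch(array, seq_len=128, batch_size=8, shuffle_buffer=256, seed=0):
    """Yield (batch_size, seq_len, features) windows with buffered shuffling."""
    rng = random.Random(seed)
    buffer, batch = [], []
    for start in range(0, len(array) - seq_len + 1, seq_len):
        buffer.append(array[start:start + seq_len])
        if len(buffer) < shuffle_buffer:
            continue
        batch.append(buffer.pop(rng.randrange(len(buffer))))   # draw a random buffered window
        if len(batch) == batch_size:
            yield np.stack(batch)
            batch = []
    rng.shuffle(buffer)                                         # drain whatever is left
    for window in buffer:
        batch.append(window)
        if len(batch) == batch_size:
            yield np.stack(batch)
            batch = []

for xb in windowed_batches_sketch(np.random.randn(10_000, 8)):
    print(xb.shape)   # (8, 128, 8)
    break
```
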
---

## 8. Sampling & Export

Defined in:

- `example/sample.py`
- `example/export_samples.py`

Export process (sketched below):

- Generate the trend with the temporal GRU
- Let diffusion generate the residuals
- Output `trend + residual`
- De-normalize continuous values
- Clamp to the observed min/max
- Restore discrete tokens from the vocabulary
- Write to CSV

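The post-processing chain can be sketched as follows; the column names and the stats/vocab dictionaries are illustrative, while `export_samples.py` reads the real values from the prepared metadata.

```python
import numpy as np
import pandas as pd

def postprocess_sketch(x_cont, x_disc, stats, vocabs, cont_cols, disc_cols):
    """De-normalize, clamp to observed ranges, and map discrete ids back to tokens."""
    out = {}
    for j, col in enumerate(cont_cols):
        vals = x_cont[:, j] * stats[col]["std"] + stats[col]["mean"]    # undo z-scoring
        out[col] = np.clip(vals, stats[col]["min"], stats[col]["max"])  # clamp to observed range
    for j, col in enumerate(disc_cols):
        id_to_token = vocabs[col]
        out[col] = [id_to_token.get(int(i), id_to_token[0]) for i in x_disc[:, j]]  # unknown ids fall back
    return pd.DataFrame(out)

stats = {"sensor_a": {"mean": 1.2, "std": 0.4, "min": 0.0, "max": 3.0}}
vocabs = {"pump_status": {0: "0", 1: "1"}}
df = postprocess_sketch(np.random.randn(16, 1), np.random.randint(0, 2, (16, 1)),
                        stats, vocabs, ["sensor_a"], ["pump_status"])
df.to_csv("generated_sample.csv", index=False)   # final CSV export step
```
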
---

## 9. Evaluation

Defined in `example/evaluate_generated.py`.

Metrics (computed against a reference split):

- **KS statistic** (continuous distributions)
- **Quantile diffs** (q05/q25/q50/q75/q95)
- **Lag-1 correlation diff** (temporal structure)
- **Discrete JSD** over vocabulary frequencies
- **Invalid token counts**

**Summary and comparison script:** `example/summary_metrics.py`
- Prints avg_ks / avg_jsd / avg_lag1_diff
- Appends a record to `example/results/metrics_history.csv`
- If a previous record exists, prints the deltas (new vs. old run)

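For reference, minimal versions of the headline metrics are sketched below using SciPy for brevity; `evaluate_generated.py` may bin, smooth, or aggregate differently, so treat these as definitions rather than the exact implementation.

```python
import numpy as np
from scipy.stats import ks_2samp
from scipy.spatial.distance import jensenshannon

def ks_stat(real: np.ndarray, gen: np.ndarray) -> float:
    # Two-sample KS statistic for a continuous column: max CDF gap, 0 means identical.
    return float(ks_2samp(real, gen).statistic)

def discrete_jsd(real_tokens: np.ndarray, gen_tokens: np.ndarray, vocab) -> float:
    # Jensen-Shannon divergence between token-frequency distributions over a shared vocab.
    p = np.array([np.mean(real_tokens == v) for v in vocab]) + 1e-12
    q = np.array([np.mean(gen_tokens == v) for v in vocab]) + 1e-12
    return float(jensenshannon(p / p.sum(), q / q.sum()) ** 2)   # squared distance = divergence

def lag1_corr_diff(real: np.ndarray, gen: np.ndarray) -> float:
    # Compare first-order autocorrelation, the temporal-consistency proxy used above.
    r = np.corrcoef(real[:-1], real[1:])[0, 1]
    g = np.corrcoef(gen[:-1], gen[1:])[0, 1]
    return float(abs(r - g))
```
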
---

## 10. Automation

`example/run_all.py` runs all stages with config-driven paths.

---

## 11. Key Engineering Decisions

- Mixed-type diffusion: separate continuous and discrete branches
- Two-stage training: temporal backbone first, diffusion on the residuals
- Positional + time embeddings for stability
- Optional inverse-variance weighting for the continuous loss
- Log1p transforms for heavy-tailed signals

---

## 12. Code Map (Key Files)

- Core model: `example/hybrid_diffusion.py`
- Training: `example/train.py`
- Temporal GRU: `example/hybrid_diffusion.py` (`TemporalGRUGenerator`)
- Data prep: `example/prepare_data.py`
- Data utilities: `example/data_utils.py`
- Sampling: `example/sample.py`
- Export: `example/export_samples.py`
- Evaluation: `example/evaluate_generated.py`
- Pipeline: `example/run_all.py`
- Config: `example/config.json`

---

## 13. Known Issues / Current Limitations

- KS can remain high → continuous distribution mismatch
- Lag-1 can fluctuate → trade-off between distributional and temporal fidelity
- The continuous loss can dominate → needs careful weighting

---

## 14. Suggested Next Steps

- Add an **SNR-weighted loss** for more stable diffusion training
- Explore **v-prediction** for the continuous branch (sketched below)
- Strengthen discrete diffusion (e.g., D3PM-style transition matrices)

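For the v-prediction item, the standard target from Salimans & Ho (2022) would slot in next to the existing `eps`/`x0` options roughly as follows; this is a sketch only and is not wired into the current loss.

```python
import torch

def v_target(x0, eps, t, alphas_cumprod):
    # v = sqrt(a_bar_t) * eps - sqrt(1 - a_bar_t) * x0; regressing onto v keeps the
    # target well-scaled across timesteps, which often stabilizes training.
    a_bar = alphas_cumprod[t].view(-1, 1, 1)
    return a_bar.sqrt() * eps - (1.0 - a_bar).sqrt() * x0
```
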
---

## 15. Summary

This project implements a **two-stage hybrid diffusion model** for ICS feature sequences: a GRU-based temporal backbone first models the sequence trend, then diffusion learns residual corrections. The pipeline covers data preparation, two-stage training, sampling, export, and evaluation. The main research challenge remains balancing **distributional fidelity (KS)** against **temporal consistency (lag-1)**.