2026-01-22 20:42:10 +08:00
parent f37a8ce179
commit 382c756dfe
10 changed files with 310 additions and 55 deletions

@@ -26,6 +26,7 @@ REPO_DIR = BASE_DIR.parent.parent
DEFAULTS = {
"data_path": REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz",
"data_glob": REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz",
"split_path": BASE_DIR / "feature_split.json",
"stats_path": BASE_DIR / "results" / "cont_stats.json",
"vocab_path": BASE_DIR / "results" / "disc_vocab.json",
@@ -41,6 +42,15 @@ DEFAULTS = {
"seed": 1337,
"log_every": 10,
"ckpt_every": 50,
"ema_decay": 0.999,
"use_ema": True,
"clip_k": 5.0,
"grad_clip": 1.0,
"use_condition": True,
"condition_type": "file_id",
"cond_dim": 32,
"use_tanh_eps": True,
"eps_scale": 1.0,
}
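
The nine new defaults added above configure the features introduced by this commit. A hedged summary of what each key appears to control, inferred from how the keys are consumed later in this file (clip_k is not referenced in the hunks shown here, so its consumer is assumed to live in one of the other changed files):

# Annotated restatement of the new DEFAULTS entries; interpretations are
# inferred from their usage in this diff, not from separate documentation.
NEW_TRAINING_DEFAULTS = {
    "ema_decay": 0.999,           # decay for the EMA shadow-weight tracker
    "use_ema": True,              # enable EMA tracking and the model_ema.pt export
    "clip_k": 5.0,                # not used in this file's hunks; assumed consumed elsewhere
    "grad_clip": 1.0,             # max norm for torch.nn.utils.clip_grad_norm_ (0 disables)
    "use_condition": True,        # condition the model on which source file a window came from
    "condition_type": "file_id",  # only "file_id" conditioning is wired up in main()
    "cond_dim": 32,               # width of the conditioning embedding passed to the model
    "use_tanh_eps": True,         # model flag; presumably bounds the predicted noise with tanh
    "eps_scale": 1.0,             # model flag; presumably scales the bounded noise prediction
}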
@@ -69,7 +79,7 @@ def parse_args():
def resolve_config_paths(config, base_dir: Path):
keys = ["data_path", "split_path", "stats_path", "vocab_path", "out_dir"]
keys = ["data_path", "data_glob", "split_path", "stats_path", "vocab_path", "out_dir"]
for key in keys:
if key in config:
# If the value is a string, convert it to a Path object
@@ -85,6 +95,26 @@ def resolve_config_paths(config, base_dir: Path):
return config
class EMA:
def __init__(self, model, decay: float):
self.decay = decay
self.shadow = {}
for name, param in model.named_parameters():
if param.requires_grad:
self.shadow[name] = param.detach().clone()
def update(self, model):
with torch.no_grad():
for name, param in model.named_parameters():
if not param.requires_grad:
continue
old = self.shadow[name]
self.shadow[name] = old * self.decay + param.detach() * (1.0 - self.decay)
def state_dict(self):
return self.shadow
def main():
args = parse_args()
config = dict(DEFAULTS)
@@ -113,25 +143,47 @@ def main():
vocab = load_json(config["vocab_path"])["vocab"]
vocab_sizes = [len(vocab[c]) for c in disc_cols]
data_paths = None
if "data_glob" in config and config["data_glob"]:
data_paths = sorted(Path(config["data_glob"]).parent.glob(Path(config["data_glob"]).name))
if data_paths:
data_paths = [safe_path(p) for p in data_paths]
if not data_paths:
data_paths = [safe_path(config["data_path"])]
use_condition = bool(config.get("use_condition")) and config.get("condition_type") == "file_id"
cond_vocab_size = len(data_paths) if use_condition else 0
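
With file_id conditioning, each globbed training file becomes one conditioning class, so the condition vocabulary is simply the number of files. A minimal sketch of that mapping, assuming the default dataset layout (windowed_batches is then expected to yield the matching integer id per window when return_file_id=True):

from pathlib import Path

# Illustrative only: one integer condition id per source file.
data_paths = sorted(Path("dataset/hai/hai-21.03").glob("train*.csv.gz"))  # assumed layout
file_id_of = {p: i for i, p in enumerate(data_paths)}
cond_vocab_size = len(data_paths)  # one embedding row per file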
device = resolve_device(str(config["device"]))
print("device", device)
model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(device)
model = HybridDiffusionModel(
cont_dim=len(cont_cols),
disc_vocab_sizes=vocab_sizes,
cond_vocab_size=cond_vocab_size,
cond_dim=int(config.get("cond_dim", 32)),
use_tanh_eps=bool(config.get("use_tanh_eps", False)),
eps_scale=float(config.get("eps_scale", 1.0)),
).to(device)
opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))
ema = EMA(model, float(config["ema_decay"])) if config.get("use_ema") else None
betas = cosine_beta_schedule(int(config["timesteps"])).to(device)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
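
For reference, cosine_beta_schedule usually follows the Nichol & Dhariwal cosine schedule. A common implementation is sketched below; the repository's own version may differ in the offset or clamping:

import math
import torch

def cosine_beta_schedule_sketch(timesteps: int, s: float = 0.008) -> torch.Tensor:
    # Define alpha_bar(t) with a squared cosine, then recover per-step betas
    # from the ratio of consecutive alpha_bar values.
    steps = torch.linspace(0, timesteps, timesteps + 1) / timesteps
    alphas_bar = torch.cos((steps + s) / (1 + s) * math.pi / 2) ** 2
    alphas_bar = alphas_bar / alphas_bar[0]
    betas = 1.0 - alphas_bar[1:] / alphas_bar[:-1]
    return betas.clamp(1e-8, 0.999)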
os.makedirs(config["out_dir"], exist_ok=True)
log_path = os.path.join(config["out_dir"], "train_log.csv")
out_dir = safe_path(config["out_dir"])
log_path = os.path.join(out_dir, "train_log.csv")
with open(log_path, "w", encoding="utf-8") as f:
f.write("epoch,step,loss,loss_cont,loss_disc\n")
with open(os.path.join(out_dir, "config_used.json"), "w", encoding="utf-8") as f:
json.dump(config, f, indent=2, default=str)  # default=str: config holds Path objects after resolve_config_paths
total_step = 0
for epoch in range(int(config["epochs"])):
for step, (x_cont, x_disc) in enumerate(
for step, batch in enumerate(
windowed_batches(
config["data_path"],
data_paths,
cont_cols,
disc_cols,
vocab,
@@ -140,8 +192,15 @@ def main():
batch_size=int(config["batch_size"]),
seq_len=int(config["seq_len"]),
max_batches=int(config["max_batches"]),
return_file_id=use_condition,
)
):
if use_condition:
x_cont, x_disc, cond = batch
cond = cond.to(device)
else:
x_cont, x_disc = batch
cond = None
x_cont = x_cont.to(device)
x_disc = x_disc.to(device)
@@ -153,21 +212,29 @@ def main():
mask_tokens = torch.tensor(vocab_sizes, device=device)
x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, int(config["timesteps"]))
eps_pred, logits = model(x_cont_t, x_disc_t, t)
eps_pred, logits = model(x_cont_t, x_disc_t, t, cond)
loss_cont = F.mse_loss(eps_pred, noise)
loss_disc = 0.0
loss_disc_count = 0
for i, logit in enumerate(logits):
if mask[:, :, i].any():
loss_disc = loss_disc + F.cross_entropy(
logit[mask[:, :, i]], x_disc[:, :, i][mask[:, :, i]]
)
loss_disc_count += 1
if loss_disc_count > 0:
loss_disc = loss_disc / loss_disc_count
lam = float(config["lambda"])
loss = lam * loss_cont + (1 - lam) * loss_disc
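
The discrete branch above only scores positions that q_sample_discrete flagged as corrupted. A common absorbing-state ("mask") corruption is sketched below, where each token is independently replaced by its column's mask id with probability t / timesteps; the repository's actual q_sample_discrete may use a different masking probability:

import torch

def q_sample_discrete_sketch(x_disc, t, mask_tokens, timesteps):
    # x_disc: (B, L, C) integer tokens; t: (B,) timestep per sample;
    # mask_tokens: (C,) mask id per column (== that column's vocab size).
    p_mask = (t.float() / timesteps).view(-1, 1, 1)               # (B, 1, 1)
    mask = torch.rand_like(x_disc, dtype=torch.float) < p_mask    # (B, L, C) bool
    x_t = torch.where(mask, mask_tokens.view(1, 1, -1).expand_as(x_disc), x_disc)
    return x_t, mask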
opt.zero_grad()
loss.backward()
if float(config.get("grad_clip", 0.0)) > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), float(config["grad_clip"]))
opt.step()
if ema is not None:
ema.update(model)
if step % int(config["log_every"]) == 0:
print("epoch", epoch, "step", step, "loss", float(loss))
@@ -185,9 +252,13 @@ def main():
"config": config,
"step": total_step,
}
torch.save(ckpt, os.path.join(config["out_dir"], "model_ckpt.pt"))
if ema is not None:
ckpt["ema"] = ema.state_dict()
torch.save(ckpt, os.path.join(out_dir, "model_ckpt.pt"))
torch.save(model.state_dict(), os.path.join(config["out_dir"], "model.pt"))
torch.save(model.state_dict(), os.path.join(out_dir, "model.pt"))
if ema is not None:
torch.save(ema.state_dict(), os.path.join(out_dir, "model_ema.pt"))
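
Because the EMA shadow tracks parameters only (buffers are not copied), restoring it for sampling is normally done with strict=False on a freshly built model. A hedged restore sketch, assuming the model is reconstructed with the same arguments recorded in config_used.json:

import os
import torch

def load_ema_weights(model, out_dir, device="cpu"):
    # `model` is assumed to be a HybridDiffusionModel built with the training-time
    # arguments; the shadow dict holds parameters only, so strict=False leaves
    # the model's buffers untouched.
    ema_state = torch.load(os.path.join(out_dir, "model_ema.pt"), map_location=device)
    model.load_state_dict(ema_state, strict=False)
    return model.eval()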
if __name__ == "__main__":