update ks

This commit is contained in:
2026-01-25 18:13:37 +08:00
parent b3c45010a4
commit bc838d7cd7
3 changed files with 47 additions and 6 deletions

View File

@@ -39,6 +39,9 @@
     "cont_clamp_x0": 5.0,
     "quantile_loss_weight": 0.1,
     "quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
+    "quantile_loss_warmup_steps": 200,
+    "quantile_loss_clip": 6.0,
+    "quantile_loss_huber_delta": 1.0,
     "shuffle_buffer": 256,
     "sample_batch_size": 8,
     "sample_seq_len": 128

View File

@@ -34,6 +34,7 @@ def main():
     loss = []
     loss_cont = []
     loss_disc = []
+    loss_quant = []
     with log_path.open("r", encoding="utf-8", newline="") as f:
         reader = csv.DictReader(f)
@@ -42,6 +43,8 @@ def main():
             loss.append(float(row["loss"]))
             loss_cont.append(float(row["loss_cont"]))
             loss_disc.append(float(row["loss_disc"]))
+            if "loss_quantile" in row:
+                loss_quant.append(float(row["loss_quantile"]))
    if not steps:
        raise SystemExit("no rows in log file: %s" % log_path)
@@ -50,6 +53,8 @@ def main():
     plt.plot(steps, loss, label="total")
     plt.plot(steps, loss_cont, label="continuous")
     plt.plot(steps, loss_disc, label="discrete")
+    if loss_quant:
+        plt.plot(steps, loss_quant, label="quantile")
     plt.xlabel("step")
     plt.ylabel("loss")
     plt.title("Training Loss")

View File

@@ -66,6 +66,9 @@ DEFAULTS = {
     "cont_clamp_x0": 0.0,
     "quantile_loss_weight": 0.0,
     "quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
+    "quantile_loss_warmup_steps": 200,
+    "quantile_loss_clip": 6.0,
+    "quantile_loss_huber_delta": 1.0,
 }
@@ -205,7 +208,11 @@ def main():
     os.makedirs(config["out_dir"], exist_ok=True)
     out_dir = safe_path(config["out_dir"])
     log_path = os.path.join(out_dir, "train_log.csv")
+    use_quantile = float(config.get("quantile_loss_weight", 0.0)) > 0
     with open(log_path, "w", encoding="utf-8") as f:
-        f.write("epoch,step,loss,loss_cont,loss_disc\n")
+        if use_quantile:
+            f.write("epoch,step,loss,loss_cont,loss_disc,loss_quantile\n")
+        else:
+            f.write("epoch,step,loss,loss_cont,loss_disc\n")
     with open(os.path.join(out_dir, "config_used.json"), "w", encoding="utf-8") as f:
         json.dump(config, f, indent=2)
@@ -290,7 +297,11 @@ def main():
     loss = lam * loss_cont + (1 - lam) * loss_disc
     q_weight = float(config.get("quantile_loss_weight", 0.0))
+    quantile_loss = 0.0
     if q_weight > 0:
+        warmup = int(config.get("quantile_loss_warmup_steps", 0))
+        if warmup > 0:
+            q_weight = q_weight * min(1.0, (total_step + 1) / float(warmup))
         q_points = config.get("quantile_points", [0.05, 0.25, 0.5, 0.75, 0.95])
         q_tensor = torch.tensor(q_points, device=device, dtype=x_cont.dtype)
         # Use normalized space for stable quantiles on x0.
@@ -304,11 +315,20 @@ def main():
else: else:
# eps prediction # eps prediction
x_gen = (x_cont_t - torch.sqrt(1.0 - a_bar_t) * eps_pred) / torch.sqrt(a_bar_t) x_gen = (x_cont_t - torch.sqrt(1.0 - a_bar_t) * eps_pred) / torch.sqrt(a_bar_t)
q_clip = float(config.get("quantile_loss_clip", 0.0))
if q_clip > 0:
x_real = torch.clamp(x_real, -q_clip, q_clip)
x_gen = torch.clamp(x_gen, -q_clip, q_clip)
x_real = x_real.view(-1, x_real.size(-1)) x_real = x_real.view(-1, x_real.size(-1))
x_gen = x_gen.view(-1, x_gen.size(-1)) x_gen = x_gen.view(-1, x_gen.size(-1))
q_real = torch.quantile(x_real, q_tensor, dim=0) q_real = torch.quantile(x_real, q_tensor, dim=0)
q_gen = torch.quantile(x_gen, q_tensor, dim=0) q_gen = torch.quantile(x_gen, q_tensor, dim=0)
quantile_loss = torch.mean(torch.abs(q_gen - q_real)) q_delta = float(config.get("quantile_loss_huber_delta", 0.0))
q_diff = q_gen - q_real
if q_delta > 0:
quantile_loss = torch.nn.functional.smooth_l1_loss(q_gen, q_real, beta=q_delta)
else:
quantile_loss = torch.mean(torch.abs(q_diff))
loss = loss + q_weight * quantile_loss loss = loss + q_weight * quantile_loss
opt.zero_grad() opt.zero_grad()
loss.backward() loss.backward()
@@ -321,6 +341,19 @@ def main():
     if step % int(config["log_every"]) == 0:
         print("epoch", epoch, "step", step, "loss", float(loss))
         with open(log_path, "a", encoding="utf-8") as f:
+            if use_quantile:
+                f.write(
+                    "%d,%d,%.6f,%.6f,%.6f,%.6f\n"
+                    % (
+                        epoch,
+                        step,
+                        float(loss),
+                        float(loss_cont),
+                        float(loss_disc),
+                        float(quantile_loss),
+                    )
+                )
+            else:
             f.write(
                 "%d,%d,%.6f,%.6f,%.6f\n"
                 % (epoch, step, float(loss), float(loss_cont), float(loss_disc))