update ks
@@ -66,6 +66,9 @@ DEFAULTS = {
     "cont_clamp_x0": 0.0,
     "quantile_loss_weight": 0.0,
     "quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
+    "quantile_loss_warmup_steps": 200,
+    "quantile_loss_clip": 6.0,
+    "quantile_loss_huber_delta": 1.0,
 }
 
 
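Note: the quantile path stays off unless quantile_loss_weight is raised above 0; the three new keys only tune warmup, clipping, and the Huber delta. A minimal sketch of enabling it (the merge via dict unpacking is an assumption about how the script builds its config; the override values are illustrative):

DEFAULTS = {
    "quantile_loss_weight": 0.0,  # 0.0 keeps the quantile penalty disabled
    "quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
    "quantile_loss_warmup_steps": 200,
    "quantile_loss_clip": 6.0,
    "quantile_loss_huber_delta": 1.0,
}
overrides = {"quantile_loss_weight": 0.1}  # any weight > 0 turns the penalty on
config = {**DEFAULTS, **overrides}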
@@ -205,8 +208,12 @@ def main():
     os.makedirs(config["out_dir"], exist_ok=True)
     out_dir = safe_path(config["out_dir"])
     log_path = os.path.join(out_dir, "train_log.csv")
+    use_quantile = float(config.get("quantile_loss_weight", 0.0)) > 0
     with open(log_path, "w", encoding="utf-8") as f:
-        f.write("epoch,step,loss,loss_cont,loss_disc\n")
+        if use_quantile:
+            f.write("epoch,step,loss,loss_cont,loss_disc,loss_quantile\n")
+        else:
+            f.write("epoch,step,loss,loss_cont,loss_disc\n")
     with open(os.path.join(out_dir, "config_used.json"), "w", encoding="utf-8") as f:
         json.dump(config, f, indent=2)
 
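Because the header row is now chosen once at startup from quantile_loss_weight, train_log.csv has a different column set depending on the run. A reader should key on the header rather than fixed column positions; a minimal sketch (read_train_log is our name, not part of the commit):

import csv

def read_train_log(path):
    # DictReader picks up whichever header the run wrote, so rows expose a
    # "loss_quantile" key only when the run had quantile_loss_weight > 0.
    with open(path, newline="", encoding="utf-8") as f:
        return list(csv.DictReader(f))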
@@ -290,7 +297,11 @@ def main():
         loss = lam * loss_cont + (1 - lam) * loss_disc
 
         q_weight = float(config.get("quantile_loss_weight", 0.0))
+        quantile_loss = 0.0
         if q_weight > 0:
+            warmup = int(config.get("quantile_loss_warmup_steps", 0))
+            if warmup > 0:
+                q_weight = q_weight * min(1.0, (total_step + 1) / float(warmup))
             q_points = config.get("quantile_points", [0.05, 0.25, 0.5, 0.75, 0.95])
             q_tensor = torch.tensor(q_points, device=device, dtype=x_cont.dtype)
             # Use normalized space for stable quantiles on x0.
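The warmup line ramps the configured weight linearly over the first quantile_loss_warmup_steps optimizer steps. In isolation (the helper name is ours, not in the patch):

def warmup_scale(total_step, warmup_steps):
    # Mirrors q_weight * min(1.0, (total_step + 1) / float(warmup)) above.
    if warmup_steps <= 0:
        return 1.0
    return min(1.0, (total_step + 1) / float(warmup_steps))

# With the default of 200: step 0 -> 0.005, step 99 -> 0.5, step 199 onward -> 1.0.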
@@ -304,11 +315,20 @@ def main():
             else:
                 # eps prediction
                 x_gen = (x_cont_t - torch.sqrt(1.0 - a_bar_t) * eps_pred) / torch.sqrt(a_bar_t)
+            q_clip = float(config.get("quantile_loss_clip", 0.0))
+            if q_clip > 0:
+                x_real = torch.clamp(x_real, -q_clip, q_clip)
+                x_gen = torch.clamp(x_gen, -q_clip, q_clip)
             x_real = x_real.view(-1, x_real.size(-1))
             x_gen = x_gen.view(-1, x_gen.size(-1))
             q_real = torch.quantile(x_real, q_tensor, dim=0)
             q_gen = torch.quantile(x_gen, q_tensor, dim=0)
-            quantile_loss = torch.mean(torch.abs(q_gen - q_real))
+            q_delta = float(config.get("quantile_loss_huber_delta", 0.0))
+            q_diff = q_gen - q_real
+            if q_delta > 0:
+                quantile_loss = torch.nn.functional.smooth_l1_loss(q_gen, q_real, beta=q_delta)
+            else:
+                quantile_loss = torch.mean(torch.abs(q_diff))
         loss = loss + q_weight * quantile_loss
         opt.zero_grad()
         loss.backward()
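For reference, the penalty assembled across the last two hunks, collected into one standalone function. This is illustrative only; the script computes the same thing inline, and the function name and defaults are ours (the defaults echo the new config keys):

import torch
import torch.nn.functional as F

def quantile_matching_loss(x_real, x_gen, q_points, clip=6.0, huber_delta=1.0):
    q = torch.tensor(q_points, device=x_real.device, dtype=x_real.dtype)
    if clip > 0:
        # Clamp both sides so a few outliers cannot dominate the extreme quantiles.
        x_real = torch.clamp(x_real, -clip, clip)
        x_gen = torch.clamp(x_gen, -clip, clip)
    # Flatten everything but the feature dim, then compare per-feature quantiles.
    x_real = x_real.reshape(-1, x_real.size(-1))
    x_gen = x_gen.reshape(-1, x_gen.size(-1))
    q_real = torch.quantile(x_real, q, dim=0)  # (len(q_points), n_features)
    q_gen = torch.quantile(x_gen, q, dim=0)
    if huber_delta > 0:
        # Huber on the quantile gaps; smooth near zero, linear in the tails.
        return F.smooth_l1_loss(q_gen, q_real, beta=huber_delta)
    return torch.mean(torch.abs(q_gen - q_real))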
@@ -321,10 +341,23 @@ def main():
         if step % int(config["log_every"]) == 0:
             print("epoch", epoch, "step", step, "loss", float(loss))
             with open(log_path, "a", encoding="utf-8") as f:
-                f.write(
-                    "%d,%d,%.6f,%.6f,%.6f\n"
-                    % (epoch, step, float(loss), float(loss_cont), float(loss_disc))
-                )
+                if use_quantile:
+                    f.write(
+                        "%d,%d,%.6f,%.6f,%.6f,%.6f\n"
+                        % (
+                            epoch,
+                            step,
+                            float(loss),
+                            float(loss_cont),
+                            float(loss_disc),
+                            float(quantile_loss),
+                        )
+                    )
+                else:
+                    f.write(
+                        "%d,%d,%.6f,%.6f,%.6f\n"
+                        % (epoch, step, float(loss), float(loss_cont), float(loss_disc))
+                    )
 
         total_step += 1
         if total_step % int(config["ckpt_every"]) == 0:
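One small detail: float(quantile_loss) is safe in the quantile branch because quantile_loss is a 0-dim tensor there, and float() extracts its scalar value; when the path is disabled it stays the Python float 0.0 and the else branch never touches it. For example:

import torch
assert float(0.0) == 0.0
assert float(torch.tensor(0.25)) == 0.25  # 0-dim tensors convert the same way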