Compare commits

3 Commits

Author SHA1 Message Date
MZ YANG
bdfc6e2aaa update 2026-01-25 17:35:48 +08:00
MZ YANG
666bc3b8a9 update 2026-01-25 17:28:24 +08:00
a870945e33 update ks 2026-01-25 17:16:57 +08:00
6 changed files with 1682 additions and 1653 deletions

View File

@@ -37,6 +37,8 @@
"cont_loss_eps": 1e-6,
"cont_target": "x0",
"cont_clamp_x0": 5.0,
"quantile_loss_weight": 0.1,
"quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
"shuffle_buffer": 256,
"sample_batch_size": 8,
"sample_seq_len": 128

View File

@@ -38,6 +38,14 @@
"cont_loss_eps": 1e-06,
"cont_target": "x0",
"cont_clamp_x0": 5.0,
"quantile_loss_weight": 0.1,
"quantile_points": [
0.05,
0.25,
0.5,
0.75,
0.95
],
"sample_batch_size": 8,
"sample_seq_len": 128
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,61 +1,61 @@
epoch,step,loss,loss_cont,loss_disc
0,0,9801.450195,14001.333008,1.724242
0,10,8648.388672,12354.390625,1.050415
0,20,6285.666992,8979.191406,0.775979
0,30,6296.939941,8995.372070,0.598658
0,40,7126.128906,10180.011719,0.402291
0,50,5381.071289,7687.099121,0.340463
1,0,17876.904297,25538.212891,0.522857
1,10,9174.909180,13106.868164,0.338315
1,20,7635.462891,10907.713867,0.211883
1,30,5425.212402,7750.165039,0.323606
1,40,4372.716309,6246.610840,0.296102
1,50,3846.437988,5494.793945,0.273941
2,0,17057.958984,24368.269531,0.564671
2,10,7089.009766,10127.021484,0.315390
2,20,4230.856445,6043.994141,0.201429
2,30,3744.107910,5348.593262,0.309118
2,40,3531.041992,5044.219238,0.295744
2,50,3570.459229,5100.528320,0.297530
3,0,14367.601562,20524.917969,0.529623
3,10,6734.334473,9620.347656,0.304721
3,20,6140.179688,8771.602539,0.193836
3,30,4089.454102,5841.940918,0.317864
3,40,3553.830811,5076.785645,0.269531
3,50,3590.448242,5129.088867,0.287063
4,0,14410.816406,20586.648438,0.543822
4,10,6411.443359,9159.058594,0.341742
4,20,3816.795166,5452.479492,0.198213
4,30,4069.170898,5812.959961,0.329676
4,40,3484.921631,4978.332520,0.296284
4,50,2802.801514,4003.873779,0.299286
5,0,13335.293945,19050.201172,0.509557
5,10,5531.156738,7901.527344,0.293409
5,20,3844.260010,5491.696777,0.241263
5,30,3619.303223,5170.297363,0.317237
5,40,3492.172852,4988.697754,0.281641
5,50,3069.457275,4384.815918,0.287269
6,0,8740.982422,12486.912109,0.483061
6,10,6110.571777,8729.239258,0.347929
6,20,3350.194092,4785.889160,0.239650
6,30,3008.237549,4297.353516,0.300327
6,40,2944.483887,4206.288574,0.273457
6,50,3018.033447,4311.365234,0.259749
7,0,6784.341309,9691.704102,0.495378
7,10,4946.872559,7066.822754,0.321541
7,20,2816.704102,4023.745361,0.274515
7,30,2991.350830,4273.229980,0.299238
7,40,3023.450684,4319.083984,0.307114
7,50,2944.912598,4206.909668,0.253002
8,0,7364.851562,10521.002930,0.498283
8,10,3897.301025,5567.439453,0.311517
8,20,3313.474854,4733.448242,0.203364
8,30,2697.139648,3852.940430,0.271355
8,40,2955.225342,4221.633301,0.273175
8,50,2932.081787,4188.575684,0.262651
9,0,4065.651611,5807.854004,0.513017
9,10,4358.108398,6225.728516,0.329111
9,20,3417.019043,4881.362305,0.218488
9,30,2917.136719,4167.226562,0.260805
9,40,2786.277832,3980.287109,0.256345
9,50,2602.660645,3717.971680,0.268723
0,0,9801.515625,14001.333008,1.724242
0,10,8648.471680,12354.386719,1.050416
0,20,6285.875000,8979.247070,0.775979
0,30,6297.153809,8995.444336,0.598639
0,40,7421.842285,10602.228516,0.402729
0,50,8768.145508,12525.551758,0.339361
1,0,17362.210938,24802.792969,0.521414
1,10,9998.125000,14282.763672,0.346019
1,20,6025.115234,8606.978516,0.211357
1,30,5529.321777,7898.669922,0.321453
1,40,4561.373047,6515.895508,0.306940
1,50,4133.944336,5905.306641,0.271862
2,0,18390.314453,26271.496094,0.559547
2,10,5629.449707,8041.824219,0.315514
2,20,5088.290527,7268.675293,0.209568
2,30,4109.000000,5869.645020,0.316234
2,40,4135.112305,5906.964355,0.288305
2,50,3574.877686,5106.638672,0.300608
3,0,16117.513672,23024.658203,0.541584
3,10,7299.515137,10427.646484,0.307387
3,20,3928.072998,5611.221680,0.202967
3,30,4386.781250,6266.497559,0.325895
3,40,3627.836182,5182.310059,0.271281
3,50,3513.194580,5018.536133,0.291253
4,0,14937.586914,21339.070312,0.552027
4,10,6087.895508,8696.762695,0.334634
4,20,3961.117676,5658.443359,0.210930
4,30,3405.418457,4864.548828,0.317895
4,40,3483.671631,4976.360840,0.296246
4,50,2833.118164,4047.002441,0.297401
5,0,12412.599609,17731.978516,0.492310
5,10,4952.285156,7074.484375,0.294987
5,20,4023.841309,5748.041504,0.219470
5,30,3416.583740,4880.517090,0.310916
5,40,3283.848389,4690.912598,0.274809
5,50,2953.851807,4219.480469,0.286249
6,0,9772.470703,13960.334961,0.525315
6,10,4856.467773,6937.604980,0.294435
6,20,3487.783203,4982.249023,0.249389
6,30,2907.010498,4152.563965,0.302258
6,40,2978.796875,4255.132324,0.272516
6,50,2954.490723,4220.402832,0.260484
7,0,5634.914062,8049.578125,0.477380
7,10,4834.394531,6906.059570,0.314162
7,20,2799.942871,3999.613770,0.267521
7,30,2899.989990,4142.533203,0.294986
7,40,2961.559570,4230.455078,0.337903
7,50,3053.434814,4361.746094,0.263959
8,0,5015.993652,7165.385742,0.495864
8,10,3965.379639,5664.615234,0.316581
8,20,3669.429688,5241.729980,0.204997
8,30,2815.938232,4022.476562,0.271451
8,40,2967.452881,4238.926758,0.263370
8,50,2930.122314,4185.593262,0.262396
9,0,4364.022461,6233.995605,0.496020
9,10,4222.906250,6032.508301,0.319415
9,20,3070.776367,4386.530762,0.233229
9,30,2839.424805,4056.029785,0.265954
9,40,2770.363770,3957.375000,0.264960
9,50,2557.437256,3653.188721,0.265549
1 epoch step loss loss_cont loss_disc
2 0 0 9801.450195 9801.515625 14001.333008 1.724242
3 0 10 8648.388672 8648.471680 12354.390625 12354.386719 1.050415 1.050416
4 0 20 6285.666992 6285.875000 8979.191406 8979.247070 0.775979
5 0 30 6296.939941 6297.153809 8995.372070 8995.444336 0.598658 0.598639
6 0 40 7126.128906 7421.842285 10180.011719 10602.228516 0.402291 0.402729
7 0 50 5381.071289 8768.145508 7687.099121 12525.551758 0.340463 0.339361
8 1 0 17876.904297 17362.210938 25538.212891 24802.792969 0.522857 0.521414
9 1 10 9174.909180 9998.125000 13106.868164 14282.763672 0.338315 0.346019
10 1 20 7635.462891 6025.115234 10907.713867 8606.978516 0.211883 0.211357
11 1 30 5425.212402 5529.321777 7750.165039 7898.669922 0.323606 0.321453
12 1 40 4372.716309 4561.373047 6246.610840 6515.895508 0.296102 0.306940
13 1 50 3846.437988 4133.944336 5494.793945 5905.306641 0.273941 0.271862
14 2 0 17057.958984 18390.314453 24368.269531 26271.496094 0.564671 0.559547
15 2 10 7089.009766 5629.449707 10127.021484 8041.824219 0.315390 0.315514
16 2 20 4230.856445 5088.290527 6043.994141 7268.675293 0.201429 0.209568
17 2 30 3744.107910 4109.000000 5348.593262 5869.645020 0.309118 0.316234
18 2 40 3531.041992 4135.112305 5044.219238 5906.964355 0.295744 0.288305
19 2 50 3570.459229 3574.877686 5100.528320 5106.638672 0.297530 0.300608
20 3 0 14367.601562 16117.513672 20524.917969 23024.658203 0.529623 0.541584
21 3 10 6734.334473 7299.515137 9620.347656 10427.646484 0.304721 0.307387
22 3 20 6140.179688 3928.072998 8771.602539 5611.221680 0.193836 0.202967
23 3 30 4089.454102 4386.781250 5841.940918 6266.497559 0.317864 0.325895
24 3 40 3553.830811 3627.836182 5076.785645 5182.310059 0.269531 0.271281
25 3 50 3590.448242 3513.194580 5129.088867 5018.536133 0.287063 0.291253
26 4 0 14410.816406 14937.586914 20586.648438 21339.070312 0.543822 0.552027
27 4 10 6411.443359 6087.895508 9159.058594 8696.762695 0.341742 0.334634
28 4 20 3816.795166 3961.117676 5452.479492 5658.443359 0.198213 0.210930
29 4 30 4069.170898 3405.418457 5812.959961 4864.548828 0.329676 0.317895
30 4 40 3484.921631 3483.671631 4978.332520 4976.360840 0.296284 0.296246
31 4 50 2802.801514 2833.118164 4003.873779 4047.002441 0.299286 0.297401
32 5 0 13335.293945 12412.599609 19050.201172 17731.978516 0.509557 0.492310
33 5 10 5531.156738 4952.285156 7901.527344 7074.484375 0.293409 0.294987
34 5 20 3844.260010 4023.841309 5491.696777 5748.041504 0.241263 0.219470
35 5 30 3619.303223 3416.583740 5170.297363 4880.517090 0.317237 0.310916
36 5 40 3492.172852 3283.848389 4988.697754 4690.912598 0.281641 0.274809
37 5 50 3069.457275 2953.851807 4384.815918 4219.480469 0.287269 0.286249
38 6 0 8740.982422 9772.470703 12486.912109 13960.334961 0.483061 0.525315
39 6 10 6110.571777 4856.467773 8729.239258 6937.604980 0.347929 0.294435
40 6 20 3350.194092 3487.783203 4785.889160 4982.249023 0.239650 0.249389
41 6 30 3008.237549 2907.010498 4297.353516 4152.563965 0.300327 0.302258
42 6 40 2944.483887 2978.796875 4206.288574 4255.132324 0.273457 0.272516
43 6 50 3018.033447 2954.490723 4311.365234 4220.402832 0.259749 0.260484
44 7 0 6784.341309 5634.914062 9691.704102 8049.578125 0.495378 0.477380
45 7 10 4946.872559 4834.394531 7066.822754 6906.059570 0.321541 0.314162
46 7 20 2816.704102 2799.942871 4023.745361 3999.613770 0.274515 0.267521
47 7 30 2991.350830 2899.989990 4273.229980 4142.533203 0.299238 0.294986
48 7 40 3023.450684 2961.559570 4319.083984 4230.455078 0.307114 0.337903
49 7 50 2944.912598 3053.434814 4206.909668 4361.746094 0.253002 0.263959
50 8 0 7364.851562 5015.993652 10521.002930 7165.385742 0.498283 0.495864
51 8 10 3897.301025 3965.379639 5567.439453 5664.615234 0.311517 0.316581
52 8 20 3313.474854 3669.429688 4733.448242 5241.729980 0.203364 0.204997
53 8 30 2697.139648 2815.938232 3852.940430 4022.476562 0.271355 0.271451
54 8 40 2955.225342 2967.452881 4221.633301 4238.926758 0.273175 0.263370
55 8 50 2932.081787 2930.122314 4188.575684 4185.593262 0.262651 0.262396
56 9 0 4065.651611 4364.022461 5807.854004 6233.995605 0.513017 0.496020
57 9 10 4358.108398 4222.906250 6225.728516 6032.508301 0.329111 0.319415
58 9 20 3417.019043 3070.776367 4881.362305 4386.530762 0.218488 0.233229
59 9 30 2917.136719 2839.424805 4167.226562 4056.029785 0.260805 0.265954
60 9 40 2786.277832 2770.363770 3980.287109 3957.375000 0.256345 0.264960
61 9 50 2602.660645 2557.437256 3717.971680 3653.188721 0.268723 0.265549

View File

@@ -64,6 +64,8 @@ DEFAULTS = {
"cont_loss_eps": 1e-6,
"cont_target": "eps", # eps | x0
"cont_clamp_x0": 0.0,
"quantile_loss_weight": 0.0,
"quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
}
@@ -282,6 +284,23 @@ def main():
lam = float(config["lambda"])
loss = lam * loss_cont + (1 - lam) * loss_disc
q_weight = float(config.get("quantile_loss_weight", 0.0))
if q_weight > 0:
q_points = config.get("quantile_points", [0.05, 0.25, 0.5, 0.75, 0.95])
q_tensor = torch.tensor(q_points, device=device, dtype=x_cont.dtype)
# Use normalized space for stable quantiles.
x_real = x_cont
if cont_target == "x0":
x_gen = eps_pred
else:
x_gen = x_cont - noise
x_real = x_real.view(-1, x_real.size(-1))
x_gen = x_gen.view(-1, x_gen.size(-1))
q_real = torch.quantile(x_real, q_tensor, dim=0)
q_gen = torch.quantile(x_gen, q_tensor, dim=0)
quantile_loss = torch.mean(torch.abs(q_gen - q_real))
loss = loss + q_weight * quantile_loss
opt.zero_grad()
loss.backward()
if float(config.get("grad_clip", 0.0)) > 0: