update
This commit is contained in:
@@ -11,20 +11,20 @@
|
|||||||
"seq_len": 128,
|
"seq_len": 128,
|
||||||
"epochs": 8,
|
"epochs": 8,
|
||||||
"max_batches": 3000,
|
"max_batches": 3000,
|
||||||
"lambda": 0.5,
|
"lambda": 0.8,
|
||||||
"lr": 0.0005,
|
"lr": 0.0005,
|
||||||
"seed": 1337,
|
"seed": 1337,
|
||||||
"log_every": 10,
|
"log_every": 10,
|
||||||
"ckpt_every": 50,
|
"ckpt_every": 50,
|
||||||
"ema_decay": 0.999,
|
"ema_decay": 0.999,
|
||||||
"use_ema": true,
|
"use_ema": true,
|
||||||
"clip_k": 5.0,
|
"clip_k": 3.0,
|
||||||
"grad_clip": 1.0,
|
"grad_clip": 1.0,
|
||||||
"use_condition": true,
|
"use_condition": true,
|
||||||
"condition_type": "file_id",
|
"condition_type": "file_id",
|
||||||
"cond_dim": 32,
|
"cond_dim": 32,
|
||||||
"use_tanh_eps": true,
|
"use_tanh_eps": true,
|
||||||
"eps_scale": 1.0,
|
"eps_scale": 0.7,
|
||||||
"sample_batch_size": 8,
|
"sample_batch_size": 8,
|
||||||
"sample_seq_len": 128
|
"sample_seq_len": 128
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -234,7 +234,17 @@ def main():
|
|||||||
out_cols = ["__cond_file_id"] + out_cols
|
out_cols = ["__cond_file_id"] + out_cols
|
||||||
|
|
||||||
os.makedirs(os.path.dirname(args.out), exist_ok=True)
|
os.makedirs(os.path.dirname(args.out), exist_ok=True)
|
||||||
with open(args.out, "w", newline="", encoding="utf-8") as f:
|
out_path = args.out
|
||||||
|
try:
|
||||||
|
f = open(out_path, "w", newline="", encoding="utf-8")
|
||||||
|
except PermissionError:
|
||||||
|
# If file is locked (e.g. open in Excel), write to a new file
|
||||||
|
stem = Path(out_path).stem
|
||||||
|
suffix = Path(out_path).suffix or ".csv"
|
||||||
|
alt = Path(out_path).with_name(f"{stem}_new{suffix}")
|
||||||
|
out_path = str(alt)
|
||||||
|
f = open(out_path, "w", newline="", encoding="utf-8")
|
||||||
|
with f:
|
||||||
writer = csv.DictWriter(f, fieldnames=out_cols)
|
writer = csv.DictWriter(f, fieldnames=out_cols)
|
||||||
writer.writeheader()
|
writer.writeheader()
|
||||||
|
|
||||||
@@ -263,7 +273,7 @@ def main():
|
|||||||
writer.writerow(row)
|
writer.writerow(row)
|
||||||
row_index += 1
|
row_index += 1
|
||||||
|
|
||||||
print("exported_csv", args.out)
|
print("exported_csv", out_path)
|
||||||
print("rows", args.batch_size * args.seq_len)
|
print("rows", args.batch_size * args.seq_len)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user