diff --git a/example/config.json b/example/config.json index d5feaf6..8fdfa9b 100644 --- a/example/config.json +++ b/example/config.json @@ -11,20 +11,20 @@ "seq_len": 128, "epochs": 8, "max_batches": 3000, - "lambda": 0.5, + "lambda": 0.8, "lr": 0.0005, "seed": 1337, "log_every": 10, "ckpt_every": 50, "ema_decay": 0.999, "use_ema": true, - "clip_k": 5.0, + "clip_k": 3.0, "grad_clip": 1.0, "use_condition": true, "condition_type": "file_id", "cond_dim": 32, "use_tanh_eps": true, - "eps_scale": 1.0, + "eps_scale": 0.7, "sample_batch_size": 8, "sample_seq_len": 128 } diff --git a/example/export_samples.py b/example/export_samples.py index 3bfed15..016ad71 100644 --- a/example/export_samples.py +++ b/example/export_samples.py @@ -234,7 +234,17 @@ def main(): out_cols = ["__cond_file_id"] + out_cols os.makedirs(os.path.dirname(args.out), exist_ok=True) - with open(args.out, "w", newline="", encoding="utf-8") as f: + out_path = args.out + try: + f = open(out_path, "w", newline="", encoding="utf-8") + except PermissionError: + # If file is locked (e.g. open in Excel), write to a new file + stem = Path(out_path).stem + suffix = Path(out_path).suffix or ".csv" + alt = Path(out_path).with_name(f"{stem}_new{suffix}") + out_path = str(alt) + f = open(out_path, "w", newline="", encoding="utf-8") + with f: writer = csv.DictWriter(f, fieldnames=out_cols) writer.writeheader() @@ -263,7 +273,7 @@ def main(): writer.writerow(row) row_index += 1 - print("exported_csv", args.out) + print("exported_csv", out_path) print("rows", args.batch_size * args.seq_len)