连续型特征在时许相关性上的不足
This commit is contained in:
@@ -48,6 +48,8 @@ def main():
|
||||
seq_len = cfg.get("sample_seq_len", cfg.get("seq_len", 64))
|
||||
batch_size = cfg.get("sample_batch_size", cfg.get("batch_size", 2))
|
||||
clip_k = cfg.get("clip_k", 5.0)
|
||||
data_glob = cfg.get("data_glob", "")
|
||||
data_path = cfg.get("data_path", "")
|
||||
run([sys.executable, str(base_dir / "prepare_data.py")])
|
||||
run([sys.executable, str(base_dir / "train.py"), "--config", args.config, "--device", args.device])
|
||||
run(
|
||||
@@ -70,7 +72,11 @@ def main():
|
||||
"--use-ema",
|
||||
]
|
||||
)
|
||||
run([sys.executable, str(base_dir / "evaluate_generated.py")])
|
||||
ref = data_glob if data_glob else data_path
|
||||
if ref:
|
||||
run([sys.executable, str(base_dir / "evaluate_generated.py"), "--reference", str(ref)])
|
||||
else:
|
||||
run([sys.executable, str(base_dir / "evaluate_generated.py")])
|
||||
run([sys.executable, str(base_dir / "plot_loss.py")])
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user