From dbb361548096536b62f50d810885439043ec08a3 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 3 Jul 2023 10:44:05 +0200 Subject: [PATCH] Update. --- README.txt | 8 ++------ main.py | 7 ++++++- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/README.txt b/README.txt index b5c4109..223565e 100644 --- a/README.txt +++ b/README.txt @@ -1,9 +1,5 @@ -For the stack experiments: - -./main.py --task=stack - -./main.py --task=stack --stack_fraction_values_for_train=0.75 +For the stack experiment: ./main.py --task=stack --stack_fraction_values_for_train=0.75 --stack_nb_stacks=3 -Each takes ~1h10min on a 4090. +Takes ~1h10min on a 4090. diff --git a/main.py b/main.py index 38dccb9..2ed6b6b 100755 --- a/main.py +++ b/main.py @@ -37,7 +37,7 @@ parser.add_argument( parser.add_argument("--log_filename", type=str, default="train.log", help=" ") -parser.add_argument("--result_dir", type=str, default="results_default") +parser.add_argument("--result_dir", type=str, default=None) parser.add_argument("--seed", type=int, default=0) @@ -144,30 +144,35 @@ if args.seed >= 0: default_args = { "picoclvr": { + "result_dir": "results_picoclvr", "nb_epochs": 25, "batch_size": 25, "nb_train_samples": 250000, "nb_test_samples": 10000, }, "mnist": { + "result_dir": "results_mnist", "nb_epochs": 25, "batch_size": 10, "nb_train_samples": 250000, "nb_test_samples": 10000, }, "maze": { + "result_dir": "results_maze", "nb_epochs": 25, "batch_size": 25, "nb_train_samples": 250000, "nb_test_samples": 10000, }, "snake": { + "result_dir": "results_snake", "nb_epochs": 5, "batch_size": 25, "nb_train_samples": 250000, "nb_test_samples": 10000, }, "stack": { + "result_dir": "results_stack", "nb_epochs": 5, "batch_size": 25, "nb_train_samples": 100000, -- 2.20.1