X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=2ed6b6b56e6e3fec57750e17de4ee8bc0cb7967b;hb=dbb361548096536b62f50d810885439043ec08a3;hp=8c4b7a1b94e51ff54ed4b3fbdfc3494a152eaf05;hpb=abebc8df53908d9f395ae2d9e20d8b00fd50ae4e;p=picoclvr.git diff --git a/main.py b/main.py index 8c4b7a1..2ed6b6b 100755 --- a/main.py +++ b/main.py @@ -37,7 +37,7 @@ parser.add_argument( parser.add_argument("--log_filename", type=str, default="train.log", help=" ") -parser.add_argument("--result_dir", type=str, default="results_default") +parser.add_argument("--result_dir", type=str, default=None) parser.add_argument("--seed", type=int, default=0) @@ -144,30 +144,35 @@ if args.seed >= 0: default_args = { "picoclvr": { + "result_dir": "results_picoclvr", "nb_epochs": 25, "batch_size": 25, "nb_train_samples": 250000, "nb_test_samples": 10000, }, "mnist": { + "result_dir": "results_mnist", "nb_epochs": 25, "batch_size": 10, "nb_train_samples": 250000, "nb_test_samples": 10000, }, "maze": { + "result_dir": "results_maze", "nb_epochs": 25, "batch_size": 25, "nb_train_samples": 250000, "nb_test_samples": 10000, }, "snake": { + "result_dir": "results_snake", "nb_epochs": 5, "batch_size": 25, "nb_train_samples": 250000, "nb_test_samples": 10000, }, "stack": { + "result_dir": "results_stack", "nb_epochs": 5, "batch_size": 25, "nb_train_samples": 100000, @@ -968,8 +973,8 @@ class TaskStack(Task): ) #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - l=50 - l=l-l%(1+self.nb_digits) + l = 50 + l = l - l % (1 + self.nb_digits) input = self.test_input[:10, :l] result = input.clone() stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)