X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=7cb8d4f96ab03542205db29e333c8b22e176a1ae;hb=74311726e42dccb8bc096e86a7e9000576099bab;hp=967923644118d08fc52a4bb76056810cb49ce99e;hpb=cf7fcbb7a946c4d1f4d29a28e0eb04940d3b0f76;p=picoclvr.git diff --git a/main.py b/main.py index 9679236..7cb8d4f 100755 --- a/main.py +++ b/main.py @@ -102,7 +102,7 @@ parser.add_argument("--snake_width", type=int, default=8) parser.add_argument("--snake_nb_colors", type=int, default=5) -parser.add_argument("--snake_length", type=int, default=400) +parser.add_argument("--snake_length", type=int, default=200) ###################################################################### @@ -143,8 +143,8 @@ default_args = { "batch_size": 25, }, "snake": { - "nb_epochs": 25, - "batch_size": 20, + "nb_epochs": 5, + "batch_size": 25, }, } @@ -689,7 +689,7 @@ class TaskSnake(Task): self.device = device self.prompt_length = prompt_length - self.train_input, self.train_prior_visits = snake.generate_sequences( + self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences( nb_train_samples, height, width, @@ -698,7 +698,7 @@ class TaskSnake(Task): prompt_length, self.device, ) - self.test_input, self.test_prior_visits = snake.generate_sequences( + self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences( nb_test_samples, height, width,