From 9047bd8185ed99c1302d8812551af3d5bd4602cb Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 22 Jun 2024 11:04:52 +0200 Subject: [PATCH] Update. --- main.py | 16 ++++++++++++++-- tasks.py | 30 +++++++++++++++++++++++++++++- 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 97c7130..b57c512 100755 --- a/main.py +++ b/main.py @@ -73,6 +73,8 @@ parser.add_argument("--dropout", type=float, default=0.1) parser.add_argument("--deterministic_synthesis", action="store_true", default=False) +parser.add_argument("--check", action="store_true", default=False) + ###################################################################### args = parser.parse_args() @@ -183,6 +185,9 @@ for n in vars(args): ###################################################################### +if args.test: + args.nb_train_samples = 1000 + args.nb_test_samples = 25 if args.physical_batch_size is None: args.physical_batch_size = args.batch_size @@ -606,6 +611,13 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)") ###################################################################### accuracy_to_make_quizzes = 0.975 +nb_new_quizzes_for_train = 1000 +nb_new_quizzes_for_test = 100 + +if args.test: + accuracy_to_make_quizzes = 0.0 + nb_new_quizzes_for_train = 10 + nb_new_quizzes_for_test = 10 for n_epoch in range(args.nb_epochs): # select the model with lowest accuracy @@ -634,8 +646,8 @@ for n_epoch in range(args.nb_epochs): model, other_models, task, - nb_for_train=1000, - nb_for_test=100, + nb_for_train=nb_new_quizzes_for_train, + nb_for_test=nb_new_quizzes_for_test, ) diff --git a/tasks.py b/tasks.py index 77493a8..50d541b 100755 --- a/tasks.py +++ b/tasks.py @@ -273,6 +273,34 @@ class World(Task): device=self.device, ) - nb_correct += (new_quizzes == result).long().min(dim=-1).values + l = self.height * self.width + direction = new_quizzes[:, l : l + 1] + direction = world.token_forward * ( + direction == world.token_backward + ) + world.token_backward * (direction == world.token_forward) + inverted_quizzes = torch.cat( + [new_quizzes[:, l + 1 :], direction, new_quizzes[:, :l]], dim=1 + ) + + inverted_result = inverted_quizzes.clone() + + masked_inplace_autoregression( + m, + self.batch_size, + inverted_result, + ar_mask, + deterministic_synthesis=True, + progress_bar_desc="solving reverse quizzes", + device=self.device, + ) + + nb_correct += ( + ( + (new_quizzes == result).long() + * (inverted_quizzes, inverted_result).long() + ) + .min(dim=-1) + .values + ) return new_quizzes, nb_correct -- 2.39.5