From 36a1440d01cc15643849f5ba421f89ac403ccd82 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 22 Jun 2024 19:24:50 +0200 Subject: [PATCH] Update. --- README.txt | 22 +++++++++++++--------- main.py | 4 ++-- tasks.py | 10 +++++----- world.py | 2 +- 4 files changed, 21 insertions(+), 17 deletions(-) diff --git a/README.txt b/README.txt index af96ee9..1bd8dbc 100644 --- a/README.txt +++ b/README.txt @@ -20,22 +20,25 @@ be solved but not by everybody. There are 5 competing GPTs. -The "world" is a 6x8 grid with one or two "birds" moving in a straight -line and bouncing on the world's borders. The colors correspond to a -fixed "z-buffer order". It could be another "world", but this one has +The "world" is a 7x9 grid with three "birds" moving in a straight line +and bouncing on the world's borders. The colors correspond to a fixed +"z-buffer order". It could be another "world", but this one has objectness, occlusion, and motion. Given a random world state, and the state after two iterations of birds moving, a "quiz" is to predict the second frame, given the first, or the opposite. -My home-baked GPT-37M trained with 250k solves this with ~99% success. +My home-baked GPT-37M trained with 250k solves this with ~99% success +[to be verified with the new setup]. At every iteration, we select the GPT with the lowest test accuracy, -and run one epoch. If its test accuracy got higher than 97.5%, it will -create new quizzes. To do so, it generates a large number of pairs of -frames, and checks which ones of these quizzes are hard but not too -hard, which means +and run one epoch. + +If its test accuracy got higher than 97.5%, it will create new +quizzes. To do so, it generates a large number of pairs of frames, and +checks which ones of these quizzes are hard but not too hard, which +means [THIS IS THE IMPORTANT BIT]: @@ -48,7 +51,8 @@ simply to deal with noise in the first frame. The GPT generates 1000 of such quizzes, that are added to the "culture", i.e. the training set. -Then training resumes. +We update the test accuracy of all the GPTs, and then we go to the +next iteration. The hope is that interesting concepts emerge (connectivity, symmetry, interior/exterior, shape vocabulary, etc.) diff --git a/main.py b/main.py index 09ae823..b6f2783 100755 --- a/main.py +++ b/main.py @@ -360,7 +360,7 @@ def create_quizzes( task.store_new_quizzes(new_quizzes[nb_for_train:], for_train=False) task.save_image( - new_quizzes[:96], + new_quizzes[:72], args.result_dir, f"world_quiz_{n_epoch:04d}_{model.id:02d}.png", log_string, @@ -404,7 +404,7 @@ if args.check: nb_new_quizzes_for_test = 10 for n_epoch in range(args.nb_epochs): - a = [(model.id, model.main_test_accuracy) for model in models] + a = [(model.id, model.main_test_accuracy.item()) for model in models] a.sort(key=lambda p: p[0]) log_string(f"current accuracies {a}") diff --git a/tasks.py b/tasks.py index 9a67127..ad95237 100755 --- a/tasks.py +++ b/tasks.py @@ -81,7 +81,7 @@ class World(Task): def save_image(self, input, result_dir, filename, logger): img = world.sample2img(input.to("cpu"), self.height, self.width) image_name = os.path.join(result_dir, filename) - torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2) + torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4) logger(f"wrote {image_name}") def make_ar_mask(self, input): @@ -101,8 +101,8 @@ class World(Task): self.batch_size = batch_size self.device = device - self.height = 6 - self.width = 8 + self.height = 7 + self.width = 9 self.train_input = world.generate( nb_train_samples, height=self.height, width=self.width @@ -126,7 +126,7 @@ class World(Task): if result_dir is not None: self.save_image( - self.train_input[:96], result_dir, f"world_train.png", logger + self.train_input[:72], result_dir, f"world_train.png", logger ) def batches(self, split="train", desc=None): @@ -222,7 +222,7 @@ class World(Task): ) self.save_image( - result[:96], + result[:72], result_dir, f"world_prediction_{n_epoch:04d}_{model.id:02d}.png", logger, diff --git a/world.py b/world.py index 4055533..05d7505 100755 --- a/world.py +++ b/world.py @@ -59,7 +59,7 @@ def generate( nb, height, width, - nb_birds=2, + nb_birds=3, nb_iterations=2, ): pairs = [] -- 2.39.5