Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 3 Dec 2022 20:06:44 +0000 (14:06 -0600)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 3 Dec 2022 20:06:44 +0000 (14:06 -0600)
main.py
picoclvr.py

diff --git a/main.py b/main.py
index b6eb6fe..aa1b517 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -216,17 +216,11 @@ class TaskPicoCLVR(Task):
     def vocabulary_size(self):
         return len(self.token2id)
 
-    def produce_results(self, n_epoch, model):
+    def test_model(self, n_epoch, model, primers_descr, nb_per_primer=1, generate_images=False):
         nb_tokens_to_generate = self.height * self.width + 3
         result_descr = [ ]
-        nb_per_primer = 8
 
-        for primer_descr in [
-                'red above green <sep> green top <sep> blue right of red <img>',
-                'there is red <sep> there is yellow <sep> there is blue <img>',
-                'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
-                'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
-        ]:
+        for primer_descr in primers_descr:
 
             results = autoregression(
                 model,
@@ -249,18 +243,57 @@ class TaskPicoCLVR(Task):
 
         log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')
 
-        img = [
-            picoclvr.descr2img(d, height = self.height, width = self.width)
-            for d in result_descr
+        np=torch.tensor(np)
+        count=torch.empty(np[:,0].max()+1,np[:,2].max()+1,dtype=torch.int64)
+        for i in range(count.size(0)):
+            for j in range(count.size(1)):
+                count[i,j]=((np[:,0]==i).long()*(np[:,2]==j).long()).sum()
+
+        if generate_images:
+            img = [
+                picoclvr.descr2img(d, height = self.height, width = self.width)
+                for d in result_descr
+            ]
+
+            img = torch.cat(img, 0)
+            image_name = f'result_picoclvr_{n_epoch:04d}.png'
+            torchvision.utils.save_image(
+                img / 255.,
+                image_name, nrow = nb_per_primer, pad_value = 0.8
+            )
+            log_string(f'wrote {image_name}')
+
+        return count
+
+    def produce_results(self, n_epoch, model):
+        primers_descr = [
+            'red above green <sep> green top <sep> blue right of red <img>',
+            'there is red <sep> there is yellow <sep> there is blue <img>',
+            'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
+            'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
         ]
 
-        img = torch.cat(img, 0)
-        image_name = f'result_picoclvr_{n_epoch:04d}.png'
-        torchvision.utils.save_image(
-            img / 255.,
-            image_name, nrow = nb_per_primer, pad_value = 0.8
+        self.test_model(
+            n_epoch, model,
+            primers_descr,
+            nb_per_primer=8, generate_images=True
         )
-        log_string(f'wrote {image_name}')
+
+        # FAR TOO SLOW!!!
+
+        # test_primers_descr=[ s.split('<img>')[0] for s in self.test_descr ]
+
+        # count=self.test_model(
+            # n_epoch, model,
+            # test_primers_descr,
+            # nb_per_primer=1, generate_images=False
+        # )
+
+        # with open(f'perf_{n_epoch:04d}.txt', 'w') as f:
+            # for i in range(count.size(0)):
+                # for j in range(count.size(1)):
+                    # f.write(f'{count[i,j]}')
+                    # f.write(" " if j<count.size(1)-1 else "\n")
 
 ######################################################################
 
index 3ecbf3a..059e352 100755 (executable)
@@ -95,7 +95,8 @@ def all_properties(height, width, nb_squares, square_i, square_j, square_c):
 
 def generate(nb, height, width,
              max_nb_squares = 5, max_nb_properties = 10,
-             nb_colors = 5):
+             nb_colors = 5,
+             pruning_criterion = None):
 
     assert nb_colors >= max_nb_squares and nb_colors <= len(color_tokens) - 1
 
@@ -117,6 +118,9 @@ def generate(nb, height, width,
 
         s = all_properties(height, width, nb_squares, square_i, square_j, square_c)
 
+        if pruning_criterion is not None:
+            s = list(filter(pruning_criterion,s))
+
         # pick at most max_nb_properties at random
 
         nb_properties = torch.randint(max_nb_properties, (1,)) + 1
@@ -206,23 +210,26 @@ def nb_properties(descr, height, width):
 ######################################################################
 
 if __name__ == '__main__':
-    descr = generate(nb = 5)
+    descr = generate(
+        nb = 5, height = 12, width = 16,
+        pruning_criterion = lambda s: not ('green' in s and ('right' in s or 'left' in s))
+    )
 
-    #print(descr2properties(descr))
-    print(nb_properties(descr))
+    print(descr2properties(descr, height = 12, width = 16))
+    print(nb_properties(descr, height = 12, width = 16))
 
     with open('picoclvr_example.txt', 'w') as f:
         for d in descr:
             f.write(f'{d}\n\n')
 
-    img = descr2img(descr)
+    img = descr2img(descr, height = 12, width = 16)
     torchvision.utils.save_image(img / 255.,
                                  'picoclvr_example.png', nrow = 16, pad_value = 0.8)
 
     import time
 
     start_time = time.perf_counter()
-    descr = generate(nb = 1000)
+    descr = generate(nb = 1000, height = 12, width = 16)
     end_time = time.perf_counter()
     print(f'{len(descr) / (end_time - start_time):.02f} samples per second')