Update.
[picoclvr.git] / grid.py
diff --git a/grid.py b/grid.py
index 08ddc23..433cfd5 100755 (executable)
--- a/grid.py
+++ b/grid.py
@@ -28,6 +28,7 @@ class GridFactory:
         self.height = height
         self.width = width
         self.max_nb_items = max_nb_items
+        self.max_nb_transformations = max_nb_transformations
         self.nb_questions = nb_questions
 
     def generate_scene(self):
@@ -44,8 +45,30 @@ class GridFactory:
             self.height, self.width
         )
 
-    def random_transformations(self):
+    def random_transformations(self, scene):
+        col, shp = scene
+        descriptions = []
         nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item()
+        transformations = torch.randint(5, (nb_transformations,))
+
+        for t in transformations:
+            if t == 0:
+                col, shp = col.flip(0), shp.flip(0)
+                descriptions += ["<chg> vertical flip"]
+            elif t == 1:
+                col, shp = col.flip(1), shp.flip(1)
+                descriptions += ["<chg> horizontal flip"]
+            elif t == 2:
+                col, shp = col.flip(0).t(), shp.flip(0).t()
+                descriptions += ["<chg> rotate 90 degrees"]
+            elif t == 3:
+                col, shp = col.flip(0).flip(1), shp.flip(0).flip(1)
+                descriptions += ["<chg> rotate 180 degrees"]
+            elif t == 4:
+                col, shp = col.flip(1).t(), shp.flip(1).t()
+                descriptions += ["<chg> rotate 270 degrees"]
+
+        return (col.contiguous(), shp.contiguous()), descriptions
 
     def print_scene(self, scene):
         col, shp = scene
@@ -118,7 +141,7 @@ class GridFactory:
 
         return properties
 
-    def generate_example(self):
+    def generate_scene_and_questions(self):
         while True:
             while True:
                 scene = self.generate_scene()
@@ -128,6 +151,8 @@ class GridFactory:
 
             start = self.grid_positions(scene)
 
+            scene, transformations = self.random_transformations(scene)
+
             for a in range(10):
                 col, shp = scene
                 col, shp = col.view(-1), shp.view(-1)
@@ -142,25 +167,53 @@ class GridFactory:
                 if len(false) >= self.nb_questions:
                     break
 
+            # print(f"{a=}")
+
             if a < 10:
                 break
 
         true = [true[k] for k in torch.randperm(len(true))[: self.nb_questions]]
         false = [false[k] for k in torch.randperm(len(false))[: self.nb_questions]]
-        true = [(q, "yes") for q in true]
-        false = [(q, "no") for q in false]
+        true = ["<prop> " + q + " <true>" for q in true]
+        false = ["<prop> " + q + " <false>" for q in false]
 
         union = true + false
         questions = [union[k] for k in torch.randperm(len(union))[: self.nb_questions]]
 
-        return scene, questions
+        result = " ".join(
+            ["<obj> " + x for x in self.grid_positions(scene)]
+            + transformations
+            + questions
+        )
+
+        return scene, result
+
+    def generate_samples(self, nb, progress_bar=None):
+        result = []
+
+        r = range(nb)
+        if progress_bar is not None:
+            r = progress_bar(r)
+
+        for _ in r:
+            result.append(self.generate_scene_and_questions()[1])
+
+        return result
 
 
 ######################################################################
 
 if __name__ == "__main__":
+    import time
+
     grid_factory = GridFactory()
-    scene, questions = grid_factory.generate_example()
+
+    start_time = time.perf_counter()
+    samples = grid_factory.generate_samples(10000)
+    end_time = time.perf_counter()
+    print(f"{len(samples) / (end_time - start_time):.02f} samples per second")
+
+    scene, questions = grid_factory.generate_scene_and_questions()
     grid_factory.print_scene(scene)
     print(questions)