Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 27 Jul 2024 07:42:55 +0000 (09:42 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 27 Jul 2024 07:42:55 +0000 (09:42 +0200)
grids.py
main.py

index ae3a71e..d41ec49 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -1520,6 +1520,40 @@ class Grids(problem.Problem):
                 if q >= 2:
                     break
 
+    def collide(self, s, r, rs):
+        i, j = r
+        for i2, j2 in rs:
+            if abs(i - i2) < s and abs(j - j2) < s:
+                return True
+        return False
+
+    def task_science_tag(self, A, f_A, B, f_B):
+        c = torch.randperm(len(self.colors) - 1)[:4] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            rs = []
+            while len(rs) < 4:
+                i, j = (
+                    torch.randint(self.height - 3, (1,)).item(),
+                    torch.randint(self.width - 3, (1,)).item(),
+                )
+                if not self.collide(s=3, r=(i, j), rs=rs):
+                    rs.append((i, j))
+
+            for k in range(len(rs)):
+                i, j = rs[k]
+                q = min(k, 2)
+                X[i, j : j + 3] = c[q]
+                X[i + 2, j : j + 3] = c[q]
+                X[i : i + 3, j] = c[q]
+                X[i : i + 3, j + 2] = c[q]
+
+                f_X[i, j : j + 3] = c[q]
+                f_X[i + 2, j : j + 3] = c[q]
+                f_X[i : i + 3, j] = c[q]
+                f_X[i : i + 3, j + 2] = c[q]
+                if q == 2:
+                    f_X[i + 1, j + 1] = c[-1]
+
     # end_tasks
 
     ######################################################################
@@ -1618,7 +1652,7 @@ if __name__ == "__main__":
 
     # for t in grids.all_tasks:
 
-    for t in [grids.task_science_dot]:
+    for t in [grids.task_science_tag]:
         print(t.__name__)
         quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
         grids.save_quizzes_as_image(
diff --git a/main.py b/main.py
index 1b68eca..bf617b5 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -552,7 +552,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 
         rejected = []
 
-        to_keep == quiz_machine.problem.trivial(c_quizzes) == False
+        to_keep = quiz_machine.problem.trivial(c_quizzes) == False
 
         if not to_keep.all():
             rejected.append(c_quizzes[to_keep == False])