Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 6 Jul 2024 08:55:58 +0000 (11:55 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 6 Jul 2024 08:55:58 +0000 (11:55 +0300)
grids.py
main.py
quiz_machine.py

index 247c146..03da16d 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -581,34 +581,36 @@ class Grids(problem.Problem):
 
     def task_islands(self, A, f_A, B, f_B):
         for X, f_X in [(A, f_A), (B, f_B)]:
+            nb_on_border = 0
+            for _ in range(10):
+                for k in torch.randperm(self.height * self.width):
+                    i, j = k % self.height, k // self.height
+                    border = (
+                        i == 0 or i == self.height - 1 or j == 0 or j == self.width - 1
+                    )
+                    no, nq, nq_diag = self.contact(X, i, j, 1)
+
+                    if (
+                        (nq > 0 and not border)
+                        or (nq == 0 and border and nb_on_border < 4)
+                    ) and nq_diag == 0:
+                        X[i, j] = 1
+                        if border:
+                            nb_on_border += 1
+
             while True:
-                i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,))
-                if (
-                    i == 0
-                    or i == self.height - 1
-                    or j == 0
-                    or j == self.width - 1
-                    or X[i, j] == 1
-                ):
-                    break
-            while True:
-                di, dj = torch.randint(3, (2,)) - 1
-                if abs(di) + abs(dj) > 0:
-                    break
-            X[i, j] = 1
-            while True:
-                i, j = i + di, j + dj
-                if i < 0 or i >= self.height or j < 0 or j >= self.width:
-                    break
-                b = (
-                    i == 0
-                    or i == self.height - 1
-                    or j == 0
-                    or j == self.width - 1
-                    or X[i, j] == 1
-                )
-                X[i, j] = 1
-                if b:
+                nb_fixes = 0
+                for i in range(1, self.height - 1):
+                    for j in range(1, self.width - 1):
+                        if (
+                            X[i, j] == 1
+                            and X[i - 1, j] + X[i + 1, j] + X[i, j - 1] + X[i, j + 1]
+                            == 1
+                        ):
+                            X[i, j] = 0
+                            nb_fixes += 1
+
+                if nb_fixes == 0:
                     break
 
     ######################################################################
@@ -681,8 +683,8 @@ if __name__ == "__main__":
 
     grids = Grids()
 
-    for t in grids.all_tasks():
-        # for t in [grids.task_islands]:
+    for t in grids.all_tasks():
+    for t in [grids.task_islands]:
         print(t.__name__)
         prompts, answers = grids.generate_prompts_and_answers(nb, tasks=[t])
         grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=4)
diff --git a/main.py b/main.py
index bfbc0e4..6b46fa0 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -418,6 +418,7 @@ def create_c_quizzes(
     )
 
     file_name = os.path.join(args.result_dir, f"culture_c_quiz_{n_epoch:04d}_logp.dat")
+
     with open(file_name, "w") as logp_file:
         while (
             valid_c_quizzes(quizzes_and_nb_correct_records, standard_validity).size(0)
index de1e8d1..cb187be 100755 (executable)
@@ -50,7 +50,8 @@ def one_batch_masked_inplace_autoregression(
             t_next = dist.sample()
 
         all_n = torch.arange(t_next.size(0))
-        seq_logproba += logits[all_n, t_next].sum(dim=-1)
+
+        seq_logproba += logits[all_n, t_next]
 
         input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
 
@@ -420,11 +421,11 @@ class QuizMachine:
 
         nb_correct = 0
 
+        seq_logproba[...] = 0.0
+
         for model in models_for_validation:
             result = c_quizzes.clone()
 
-            seq_logproba[...] = 0.0
-
             ar_mask = self.make_ar_mask(result)
 
             masked_inplace_autoregression(