Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 3 Aug 2024 15:39:09 +0000 (17:39 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 3 Aug 2024 15:39:09 +0000 (17:39 +0200)
grids.py
main.py

index 05c3057..f195144 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -257,6 +257,9 @@ class Grids(problem.Problem):
         self.token_B = self.token_f_A + 1
         self.token_f_B = self.token_B + 1
 
+        self.nb_rec_max = 5
+        self.rfree = torch.tensor([])
+
         self.l2tok = {
             "A": self.token_A,
             "f_A": self.token_f_A,
@@ -574,15 +577,39 @@ class Grids(problem.Problem):
 
     ######################################################################
 
+    def contact_matrices(self, rn, ri, rj, rz):
+        return (
+            (
+                (
+                    (
+                        (ri[:, :, None, 0] == ri[:, None, :, 1] + 1)
+                        | (ri[:, :, None, 1] + 1 == ri[:, None, :, 0])
+                    )
+                    & (rj[:, :, None, 0] <= rj[:, None, :, 1])
+                    & (rj[:, :, None, 1] >= rj[:, None, :, 0])
+                )
+                | (
+                    (
+                        (rj[:, :, None, 0] == rj[:, None, :, 1] + 1)
+                        | (rj[:, :, None, 1] + 1 == rj[:, None, :, 0])
+                    )
+                    & (ri[:, :, None, 0] <= ri[:, None, :, 1])
+                    & (ri[:, :, None, 1] >= ri[:, None, :, 0])
+                )
+            )
+            # & (rz[:, :, None] == rz[:, None, :])
+            & (n[None, :, None] < rn[:, None, None])
+            & (n[None, None, :] < n[None, :, None])
+        )
+
     def sample_rworld_states(self, N=1000):
-        nb_rec_max = 5
         while True:
-            rn = torch.randint(nb_rec_max - 1, (N,)) + 2
-            ri = torch.randint(self.height, (N, nb_rec_max, 2)).sort(dim=2).values
-            rj = torch.randint(self.width, (N, nb_rec_max, 2)).sort(dim=2).values
-            rz = torch.randint(2, (N, nb_rec_max))
-            rc = torch.randint(self.nb_colors - 1, (N, nb_rec_max)) + 1
-            n = torch.arange(nb_rec_max)
+            rn = torch.randint(self.nb_rec_max - 1, (N,)) + 2
+            ri = torch.randint(self.height, (N, self.nb_rec_max, 2)).sort(dim=2).values
+            rj = torch.randint(self.width, (N, self.nb_rec_max, 2)).sort(dim=2).values
+            rz = torch.randint(2, (N, self.nb_rec_max))
+            rc = torch.randint(self.nb_colors - 1, (N, self.nb_rec_max)) + 1
+            n = torch.arange(self.nb_rec_max)
             nb_collisions = (
                 (
                     (ri[:, :, None, 0] <= ri[:, None, :, 1])
@@ -607,17 +634,34 @@ class Grids(problem.Problem):
                 self.rj = rj[no_collision]
                 self.rz = rz[no_collision]
                 self.rc = rc[no_collision]
+
+                nb_contact = (
+                    contact_matrices(rn, ri, rj, rz).long().flatten(1).sum(dim=1)
+                )
+
+                self.rcontact = nb_contact > 0
+                self.rfree = torch.full((self.rn.size(0),), True)
+
                 break
 
-    def task_rworld_change_color(self, A, f_A, B, f_B):
-        nb_rec = 3
-        c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
+    def get_recworld_state(self):
+        if not self.rfree.any():
+            self.sample_rworld_states()
+        k = torch.arange(self.rn.size(0))[self.rfree]
+        k = k[torch.randint(k.size(0), (1,))].item()
+        self.rfree[k] = False
+        return self.rn[k], self.ri[k], self.rj[k], self.rz[k], self.rc[k]
+
+    def draw_state(self, X, rn, ri, rj, rz, rc):
+        for n in sorted(list(range(rn)), key=lambda n: rz[n].item()):
+            X[ri[n, 0] : ri[n, 1] + 1, rj[n, 0] : rj[n, 1] + 1] = rc[n]
+
+    def task_recworld_immobile(self, A, f_A, B, f_B):
         for X, f_X in [(A, f_A), (B, f_B)]:
-            r = self.rec_coo(nb_rec, prevent_overlap=True)
-            for n in range(nb_rec):
-                i1, j1, i2, j2 = r[n]
-                X[i1:i2, j1:j2] = c[n]
-                f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
+            rn, ri, rj, rz, rc = self.get_recworld_state()
+            self.draw_state(X, rn, ri, rj, rz, rc)
+            ri += 1
+            self.draw_state(f_X, rn, ri, rj, rz, rc)
 
     ######################################################################
 
@@ -1703,8 +1747,6 @@ if __name__ == "__main__":
     # grids = Grids(max_nb_cached_chunks=5, chunk_size=100, nb_threads=4)
 
     grids = Grids()
-    grids.sample_rworld_states()
-    exit(0)
 
     # nb = 5
     # quizzes = grids.generate_w_quizzes_(nb, tasks=[grids.task_fill])
@@ -1746,7 +1788,7 @@ if __name__ == "__main__":
 
     # for t in grids.all_tasks:
 
-    for t in [grids.task_science_tag]:
+    for t in [grids.task_recworld_immobile]:
         print(t.__name__)
         w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
         grids.save_quizzes_as_image(
@@ -1756,7 +1798,7 @@ if __name__ == "__main__":
             comments=[f"{t.__name__} #{k}" for k in range(w_quizzes.size(0))],
         )
 
-    exit(0)
+    exit(0)
 
     nb = 1000
 
diff --git a/main.py b/main.py
index 36b58e2..9a8bd43 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -478,29 +478,63 @@ c_quizzes_procedure = [
 
 
 def save_additional_results(models, science_w_quizzes):
+    # Save generated quizzes with the successive steps
+
     for model in models:
         recorder = []
 
         c_quizzes = quiz_machine.generate_c_quizzes(
-            32,
+            64,
             model_for_generation=model,
             procedure=c_quizzes_procedure,
             recorder=recorder,
         )
 
+        ##
+
+        probas = 0
+
+        for a in range(args.nb_averaging_rounds):
+            # This is nb_quizzes x nb_models
+
+            seq_logproba = quiz_machine.models_logprobas(
+                models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
+            ) + quiz_machine.models_logprobas(
+                models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
+            )
+
+            probas += seq_logproba.exp()
+
+        probas /= args.nb_averaging_rounds
+
+        comments = []
+
+        for l in seq_logproba:
+            comments.append("proba " + " ".join([f"{x.exp().item():.02f}" for x in l]))
+
+        ##
+
         c_quizzes = torch.cat([c[:, None, :] for c, _, in recorder], dim=1)
         predicted_parts = torch.cat([t[:, None, :] for _, t in recorder], dim=1)
-        nrow = c_quizzes.size(1)
+        nb_steps = c_quizzes.size(1)
         c_quizzes = c_quizzes.reshape(-1, c_quizzes.size(-1))
         predicted_parts = predicted_parts.reshape(-1, predicted_parts.size(-1))
 
+        # We have comments only for the final quiz, not the successive
+        # steps, so we have to add nb_steps-1 empty comments
+
+        steps_comments = []
+        for c in comments:
+            steps_comments += [""] * (nb_steps - 1) + [c]
+
         filename = f"non_validated_{n_epoch:04d}_{model.id:02d}.png"
         quiz_machine.problem.save_quizzes_as_image(
             args.result_dir,
             filename,
             quizzes=c_quizzes,
             predicted_parts=predicted_parts,
-            nrow=nrow,
+            comments=steps_comments,
+            nrow=nb_steps * 2,  # two quiz per row
         )
         log_string(f"wrote {filename}")