Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 3 Aug 2024 17:51:30 +0000 (19:51 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 3 Aug 2024 17:51:30 +0000 (19:51 +0200)
grids.py

index f195144..1a31a36 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -578,6 +578,7 @@ class Grids(problem.Problem):
     ######################################################################
 
     def contact_matrices(self, rn, ri, rj, rz):
+        n = torch.arange(self.nb_rec_max)
         return (
             (
                 (
@@ -604,9 +605,19 @@ class Grids(problem.Problem):
 
     def sample_rworld_states(self, N=1000):
         while True:
+            ri = (
+                torch.randint(self.height - 2, (N, self.nb_rec_max, 2))
+                .sort(dim=2)
+                .values
+            )
+            ri[:, :, 1] += 2
+            rj = (
+                torch.randint(self.width - 2, (N, self.nb_rec_max, 2))
+                .sort(dim=2)
+                .values
+            )
+            rj[:, :, 1] += 2
             rn = torch.randint(self.nb_rec_max - 1, (N,)) + 2
-            ri = torch.randint(self.height, (N, self.nb_rec_max, 2)).sort(dim=2).values
-            rj = torch.randint(self.width, (N, self.nb_rec_max, 2)).sort(dim=2).values
             rz = torch.randint(2, (N, self.nb_rec_max))
             rc = torch.randint(self.nb_colors - 1, (N, self.nb_rec_max)) + 1
             n = torch.arange(self.nb_rec_max)
@@ -636,7 +647,7 @@ class Grids(problem.Problem):
                 self.rc = rc[no_collision]
 
                 nb_contact = (
-                    contact_matrices(rn, ri, rj, rz).long().flatten(1).sum(dim=1)
+                    self.contact_matrices(rn, ri, rj, rz).long().flatten(1).sum(dim=1)
                 )
 
                 self.rcontact = nb_contact > 0