From fd52ecea8335abdad138bc5707a8ad3cfe3b79cf Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 3 Aug 2024 19:51:30 +0200 Subject: [PATCH] Update. --- grids.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/grids.py b/grids.py index f195144..1a31a36 100755 --- a/grids.py +++ b/grids.py @@ -578,6 +578,7 @@ class Grids(problem.Problem): ###################################################################### def contact_matrices(self, rn, ri, rj, rz): + n = torch.arange(self.nb_rec_max) return ( ( ( @@ -604,9 +605,19 @@ class Grids(problem.Problem): def sample_rworld_states(self, N=1000): while True: + ri = ( + torch.randint(self.height - 2, (N, self.nb_rec_max, 2)) + .sort(dim=2) + .values + ) + ri[:, :, 1] += 2 + rj = ( + torch.randint(self.width - 2, (N, self.nb_rec_max, 2)) + .sort(dim=2) + .values + ) + rj[:, :, 1] += 2 rn = torch.randint(self.nb_rec_max - 1, (N,)) + 2 - ri = torch.randint(self.height, (N, self.nb_rec_max, 2)).sort(dim=2).values - rj = torch.randint(self.width, (N, self.nb_rec_max, 2)).sort(dim=2).values rz = torch.randint(2, (N, self.nb_rec_max)) rc = torch.randint(self.nb_colors - 1, (N, self.nb_rec_max)) + 1 n = torch.arange(self.nb_rec_max) @@ -636,7 +647,7 @@ class Grids(problem.Problem): self.rc = rc[no_collision] nb_contact = ( - contact_matrices(rn, ri, rj, rz).long().flatten(1).sum(dim=1) + self.contact_matrices(rn, ri, rj, rz).long().flatten(1).sum(dim=1) ) self.rcontact = nb_contact > 0 -- 2.39.5