######################################################################
def contact_matrices(self, rn, ri, rj, rz):
+ n = torch.arange(self.nb_rec_max)
return (
(
(
def sample_rworld_states(self, N=1000):
while True:
+ ri = (
+ torch.randint(self.height - 2, (N, self.nb_rec_max, 2))
+ .sort(dim=2)
+ .values
+ )
+ ri[:, :, 1] += 2
+ rj = (
+ torch.randint(self.width - 2, (N, self.nb_rec_max, 2))
+ .sort(dim=2)
+ .values
+ )
+ rj[:, :, 1] += 2
rn = torch.randint(self.nb_rec_max - 1, (N,)) + 2
- ri = torch.randint(self.height, (N, self.nb_rec_max, 2)).sort(dim=2).values
- rj = torch.randint(self.width, (N, self.nb_rec_max, 2)).sort(dim=2).values
rz = torch.randint(2, (N, self.nb_rec_max))
rc = torch.randint(self.nb_colors - 1, (N, self.nb_rec_max)) + 1
n = torch.arange(self.nb_rec_max)
self.rc = rc[no_collision]
nb_contact = (
- contact_matrices(rn, ri, rj, rz).long().flatten(1).sum(dim=1)
+ self.contact_matrices(rn, ri, rj, rz).long().flatten(1).sum(dim=1)
)
self.rcontact = nb_contact > 0