+ rnd = torch.rand(nb, height, width)
+ coins = torch.zeros(nb, T, height, width, dtype=torch.int64)
+ rnd = rnd * (1 - wall.clamp(max=1))
+ for k in range(nb_coins):
+ coins[:, 0] = coins[:, 0] + (
+ rnd.flatten(1).argmax(dim=1)[:, None]
+ == torch.arange(rnd.flatten(1).size(1))[None, :]
+ ).long().reshape(rnd.size())
+
+ rnd = rnd * (1 - coins[:, 0].clamp(max=1))
+