+ # we want uniform over the combinations of 1 to max_nb_squares
+ # pixels of nb_colors
+ logits = math.log(nb_colors) * torch.arange(1, max_nb_squares + 1).float()
+ dist = torch.distributions.categorical.Categorical(logits=logits)
+ nb_squares = dist.sample((1,)) + 1
+ # nb_squares = torch.randint(max_nb_squares, (1,)) + 1