# Written by Francois Fleuret <francois@fleuret.org>
+import math
import torch, torchvision
import torch.nn.functional as F
descr = []
for n in range(nb):
- nb_squares = torch.randint(max_nb_squares, (1,)) + 1
+ # we want uniform over the combinations of 1 to max_nb_squares
+ # pixels of nb_colors
+ logits = math.log(nb_colors) * torch.arange(1, max_nb_squares + 1).float()
+ dist = torch.distributions.categorical.Categorical(logits=logits)
+ nb_squares = dist.sample((1,)) + 1
+ # nb_squares = torch.randint(max_nb_squares, (1,)) + 1
square_position = torch.randperm(height * width)[:nb_squares]
# color 0 is white and reserved for the background