X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=picoclvr.py;h=0cd306243fd79ee90728eb16340581aad077914f;hb=6045e9a7dd61f0dab60bd1c6ff71f6bd5c32778b;hp=5da39436749ac7dcff688a0fea476a8d507d3371;hpb=3f09462033feac19ad72ac1a4b8690e6330df22d;p=picoclvr.git diff --git a/picoclvr.py b/picoclvr.py index 5da3943..0cd3062 100755 --- a/picoclvr.py +++ b/picoclvr.py @@ -5,6 +5,7 @@ # Written by Francois Fleuret +import math import torch, torchvision import torch.nn.functional as F @@ -201,7 +202,12 @@ def generate( descr = [] for n in range(nb): - nb_squares = torch.randint(max_nb_squares, (1,)) + 1 + # we want uniform over the combinations of 1 to max_nb_squares + # pixels of nb_colors + logits = math.log(nb_colors) * torch.arange(1, max_nb_squares + 1).float() + dist = torch.distributions.categorical.Categorical(logits=logits) + nb_squares = dist.sample((1,)) + 1 + # nb_squares = torch.randint(max_nb_squares, (1,)) + 1 square_position = torch.randperm(height * width)[:nb_squares] # color 0 is white and reserved for the background