X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=picoclvr.py;h=0cd306243fd79ee90728eb16340581aad077914f;hb=fdc61b7e50e029aac58b10f377acdce549532f84;hp=bd0470f6279b7bdf5bf1c715eb848e4751fc5810;hpb=eea23df18f107fc65c810261c7775a9393ef7c8e;p=picoclvr.git diff --git a/picoclvr.py b/picoclvr.py index bd0470f..0cd3062 100755 --- a/picoclvr.py +++ b/picoclvr.py @@ -5,6 +5,7 @@ # Written by Francois Fleuret +import math import torch, torchvision import torch.nn.functional as F @@ -196,14 +197,17 @@ def generate( nb_colors=5, pruner=None, ): - assert nb_colors >= max_nb_squares and nb_colors <= len(color_name2rgb) - 1 descr = [] for n in range(nb): - - nb_squares = torch.randint(max_nb_squares, (1,)) + 1 + # we want uniform over the combinations of 1 to max_nb_squares + # pixels of nb_colors + logits = math.log(nb_colors) * torch.arange(1, max_nb_squares + 1).float() + dist = torch.distributions.categorical.Categorical(logits=logits) + nb_squares = dist.sample((1,)) + 1 + # nb_squares = torch.randint(max_nb_squares, (1,)) + 1 square_position = torch.randperm(height * width)[:nb_squares] # color 0 is white and reserved for the background @@ -242,7 +246,6 @@ def generate( def descr2img(descr, height, width): - result = [] def token2color(t): @@ -268,7 +271,6 @@ def descr2img(descr, height, width): def descr2properties(descr, height, width): - if type(descr) == list: return [descr2properties(d, height, width) for d in descr] @@ -313,7 +315,6 @@ def descr2properties(descr, height, width): def nb_properties(descr, height, width, pruner=None): - if type(descr) == list: return [nb_properties(d, height, width, pruner) for d in descr]