Update.
[picoclvr.git] / picoclvr.py
index bd0470f..0cd3062 100755 (executable)
@@ -5,6 +5,7 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
+import math
 import torch, torchvision
 import torch.nn.functional as F
 
@@ -196,14 +197,17 @@ def generate(
     nb_colors=5,
     pruner=None,
 ):
-
     assert nb_colors >= max_nb_squares and nb_colors <= len(color_name2rgb) - 1
 
     descr = []
 
     for n in range(nb):
-
-        nb_squares = torch.randint(max_nb_squares, (1,)) + 1
+        # we want uniform over the combinations of 1 to max_nb_squares
+        # pixels of nb_colors
+        logits = math.log(nb_colors) * torch.arange(1, max_nb_squares + 1).float()
+        dist = torch.distributions.categorical.Categorical(logits=logits)
+        nb_squares = dist.sample((1,)) + 1
+        # nb_squares = torch.randint(max_nb_squares, (1,)) + 1
         square_position = torch.randperm(height * width)[:nb_squares]
 
         # color 0 is white and reserved for the background
@@ -242,7 +246,6 @@ def generate(
 
 
 def descr2img(descr, height, width):
-
     result = []
 
     def token2color(t):
@@ -268,7 +271,6 @@ def descr2img(descr, height, width):
 
 
 def descr2properties(descr, height, width):
-
     if type(descr) == list:
         return [descr2properties(d, height, width) for d in descr]
 
@@ -313,7 +315,6 @@ def descr2properties(descr, height, width):
 
 
 def nb_properties(descr, height, width, pruner=None):
-
     if type(descr) == list:
         return [nb_properties(d, height, width, pruner) for d in descr]