Update.
[mygpt.git] / picoclvr.py
index f4d7a65..601bdf7 100755 (executable)
@@ -71,7 +71,9 @@ color_tokens = dict( [ (n, c) for n, c in zip(color_names, colors) ] )
 
 ######################################################################
 
-def generate(nb, height = 6, width = 8, max_nb_squares = 5, max_nb_statements = 10, many_colors = False):
+def generate(nb, height = 6, width = 8,
+             max_nb_squares = 5, max_nb_statements = 10,
+             many_colors = False):
 
     nb_colors =  len(color_tokens) - 1 if many_colors else max_nb_squares
 
@@ -81,6 +83,7 @@ def generate(nb, height = 6, width = 8, max_nb_squares = 5, max_nb_statements =
 
         nb_squares = torch.randint(max_nb_squares, (1,)) + 1
         square_position = torch.randperm(height * width)[:nb_squares]
+        # color 0 is white and reserved for the background
         square_c = torch.randperm(nb_colors)[:nb_squares] + 1
         square_i = square_position.div(width, rounding_mode = 'floor')
         square_j = square_position % width