######################################################################
-def generate(nb, height = 6, width = 8, max_nb_squares = 5, max_nb_statements = 10, many_colors = False):
+def generate(nb, height = 6, width = 8,
+ max_nb_squares = 5, max_nb_statements = 10,
+ many_colors = False):
nb_colors = len(color_tokens) - 1 if many_colors else max_nb_squares
nb_squares = torch.randint(max_nb_squares, (1,)) + 1
square_position = torch.randperm(height * width)[:nb_squares]
+ # color 0 is white and reserved for the background
square_c = torch.randperm(nb_colors)[:nb_squares] + 1
square_i = square_position.div(width, rounding_mode = 'floor')
square_j = square_position % width