X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=picoclvr.py;h=601bdf7dbe7bce3291b7b7e07cecc1b9587e94ca;hb=119d5e84350bcab97c06a5c30227a294ebadb3c3;hp=f4d7a65a65b9dade2d3d73386db37cfda4f653d0;hpb=97e9ce02762148ac3cdedf22336034fdc75754b2;p=mygpt.git diff --git a/picoclvr.py b/picoclvr.py index f4d7a65..601bdf7 100755 --- a/picoclvr.py +++ b/picoclvr.py @@ -71,7 +71,9 @@ color_tokens = dict( [ (n, c) for n, c in zip(color_names, colors) ] ) ###################################################################### -def generate(nb, height = 6, width = 8, max_nb_squares = 5, max_nb_statements = 10, many_colors = False): +def generate(nb, height = 6, width = 8, + max_nb_squares = 5, max_nb_statements = 10, + many_colors = False): nb_colors = len(color_tokens) - 1 if many_colors else max_nb_squares @@ -81,6 +83,7 @@ def generate(nb, height = 6, width = 8, max_nb_squares = 5, max_nb_statements = nb_squares = torch.randint(max_nb_squares, (1,)) + 1 square_position = torch.randperm(height * width)[:nb_squares] + # color 0 is white and reserved for the background square_c = torch.randperm(nb_colors)[:nb_squares] + 1 square_i = square_position.div(width, rounding_mode = 'floor') square_j = square_position % width