X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=picoclvr.py;h=601bdf7dbe7bce3291b7b7e07cecc1b9587e94ca;hb=119d5e84350bcab97c06a5c30227a294ebadb3c3;hp=712da1760764b85d3639d7427e16d85a9553bb4b;hpb=046f35f38d629c9854104e855a53f0142449138f;p=mygpt.git diff --git a/picoclvr.py b/picoclvr.py index 712da17..601bdf7 100755 --- a/picoclvr.py +++ b/picoclvr.py @@ -83,6 +83,7 @@ def generate(nb, height = 6, width = 8, nb_squares = torch.randint(max_nb_squares, (1,)) + 1 square_position = torch.randperm(height * width)[:nb_squares] + # color 0 is white and reserved for the background square_c = torch.randperm(nb_colors)[:nb_squares] + 1 square_i = square_position.div(width, rounding_mode = 'floor') square_j = square_position % width