X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=picoclvr.py;h=5da39436749ac7dcff688a0fea476a8d507d3371;hb=687d5b2d9f465577665991b84faec7c789685271;hp=bd0470f6279b7bdf5bf1c715eb848e4751fc5810;hpb=eea23df18f107fc65c810261c7775a9393ef7c8e;p=picoclvr.git diff --git a/picoclvr.py b/picoclvr.py index bd0470f..5da3943 100755 --- a/picoclvr.py +++ b/picoclvr.py @@ -196,13 +196,11 @@ def generate( nb_colors=5, pruner=None, ): - assert nb_colors >= max_nb_squares and nb_colors <= len(color_name2rgb) - 1 descr = [] for n in range(nb): - nb_squares = torch.randint(max_nb_squares, (1,)) + 1 square_position = torch.randperm(height * width)[:nb_squares] @@ -242,7 +240,6 @@ def generate( def descr2img(descr, height, width): - result = [] def token2color(t): @@ -268,7 +265,6 @@ def descr2img(descr, height, width): def descr2properties(descr, height, width): - if type(descr) == list: return [descr2properties(d, height, width) for d in descr] @@ -313,7 +309,6 @@ def descr2properties(descr, height, width): def nb_properties(descr, height, width, pruner=None): - if type(descr) == list: return [nb_properties(d, height, width, pruner) for d in descr]