X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=picoclvr.py;h=2d57505ef078204f3c85ba4da98a0253324e1e6e;hb=086ec8f8d2ffeaac270fbedd991bb79122db7fdf;hp=8771e21bc48dc3203df327fcbe6e7c0f564213ef;hpb=6362bac43b1656b75db9acb1722134eab0d7191b;p=mygpt.git diff --git a/picoclvr.py b/picoclvr.py index 8771e21..2d57505 100755 --- a/picoclvr.py +++ b/picoclvr.py @@ -93,11 +93,11 @@ def all_properties(height, width, nb_squares, square_i, square_j, square_c): ###################################################################### -def generate(nb, height = 6, width = 8, +def generate(nb, height, width, max_nb_squares = 5, max_nb_properties = 10, - many_colors = False): + nb_colors = 5): - nb_colors = len(color_tokens) - 1 if many_colors else max_nb_squares + assert nb_colors >= max_nb_squares and nb_colors <= len(color_tokens) - 1 descr = [ ] @@ -129,7 +129,7 @@ def generate(nb, height = 6, width = 8, ###################################################################### -def descr2img(descr, height = 6, width = 8): +def descr2img(descr, height, width): if type(descr) == list: return torch.cat([ descr2img(d, height, width) for d in descr ], 0) @@ -152,7 +152,7 @@ def descr2img(descr, height = 6, width = 8): ###################################################################### -def descr2properties(descr, height = 6, width = 8): +def descr2properties(descr, height, width): if type(descr) == list: return [ descr2properties(d, height, width) for d in descr ] @@ -181,7 +181,7 @@ def descr2properties(descr, height = 6, width = 8): ###################################################################### -def nb_missing_properties(descr, height = 6, width = 8): +def nb_missing_properties(descr, height, width): if type(descr) == list: return [ nb_missing_properties(d, height, width) for d in descr ] @@ -190,9 +190,11 @@ def nb_missing_properties(descr, height = 6, width = 8): d = d[0].strip().split('') d = [ x.strip() for x in d ] - missing_properties = set(d) - set(descr2properties(descr, height, width)) + requested_properties = set(d) + all_properties = set(descr2properties(descr, height, width)) + missing_properties = requested_properties - all_properties - return len(missing_properties) + return (len(requested_properties), len(all_properties), len(missing_properties)) ######################################################################