X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=picoclvr.py;h=3ecbf3aa40d3e055c8b329e94512565c3765a4cb;hb=82ddf9ca322e6fcc8f9364a696c26d15841d13d8;hp=8771e21bc48dc3203df327fcbe6e7c0f564213ef;hpb=6362bac43b1656b75db9acb1722134eab0d7191b;p=mygpt.git
diff --git a/picoclvr.py b/picoclvr.py
index 8771e21..3ecbf3a 100755
--- a/picoclvr.py
+++ b/picoclvr.py
@@ -93,11 +93,11 @@ def all_properties(height, width, nb_squares, square_i, square_j, square_c):
######################################################################
-def generate(nb, height = 6, width = 8,
+def generate(nb, height, width,
max_nb_squares = 5, max_nb_properties = 10,
- many_colors = False):
+ nb_colors = 5):
- nb_colors = len(color_tokens) - 1 if many_colors else max_nb_squares
+ assert nb_colors >= max_nb_squares and nb_colors <= len(color_tokens) - 1
descr = [ ]
@@ -129,7 +129,7 @@ def generate(nb, height = 6, width = 8,
######################################################################
-def descr2img(descr, height = 6, width = 8):
+def descr2img(descr, height, width):
if type(descr) == list:
return torch.cat([ descr2img(d, height, width) for d in descr ], 0)
@@ -152,7 +152,7 @@ def descr2img(descr, height = 6, width = 8):
######################################################################
-def descr2properties(descr, height = 6, width = 8):
+def descr2properties(descr, height, width):
if type(descr) == list:
return [ descr2properties(d, height, width) for d in descr ]
@@ -163,6 +163,7 @@ def descr2properties(descr, height = 6, width = 8):
seen = {}
if len(d) != height * width: return []
+
for k, x in enumerate(d):
if x != color_names[0]:
if x in color_tokens:
@@ -171,9 +172,15 @@ def descr2properties(descr, height = 6, width = 8):
return []
seen[x] = (color_id[x], k // width, k % width)
- square_c = torch.tensor( [ x[0] for x in seen.values() ] )
- square_i = torch.tensor( [ x[1] for x in seen.values() ] )
- square_j = torch.tensor( [ x[2] for x in seen.values() ] )
+ square_infos = tuple(zip(*seen.values()))
+ if square_infos:
+ square_c = torch.tensor(square_infos[0])
+ square_i = torch.tensor(square_infos[1])
+ square_j = torch.tensor(square_infos[2])
+ else:
+ square_c = torch.tensor([])
+ square_i = torch.tensor([])
+ square_j = torch.tensor([])
s = all_properties(height, width, len(seen), square_i, square_j, square_c)
@@ -181,18 +188,20 @@ def descr2properties(descr, height = 6, width = 8):
######################################################################
-def nb_missing_properties(descr, height = 6, width = 8):
+def nb_properties(descr, height, width):
if type(descr) == list:
- return [ nb_missing_properties(d, height, width) for d in descr ]
+ return [ nb_properties(d, height, width) for d in descr ]
d = descr.split('', 1)
if len(d) == 0: return 0
d = d[0].strip().split('')
d = [ x.strip() for x in d ]
- missing_properties = set(d) - set(descr2properties(descr, height, width))
+ requested_properties = set(d)
+ all_properties = set(descr2properties(descr, height, width))
+ missing_properties = requested_properties - all_properties
- return len(missing_properties)
+ return (len(requested_properties), len(all_properties), len(missing_properties))
######################################################################
@@ -200,7 +209,7 @@ if __name__ == '__main__':
descr = generate(nb = 5)
#print(descr2properties(descr))
- print(nb_missing_properties(descr))
+ print(nb_properties(descr))
with open('picoclvr_example.txt', 'w') as f:
for d in descr: