X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=picoclvr.py;h=3ecbf3aa40d3e055c8b329e94512565c3765a4cb;hb=82ddf9ca322e6fcc8f9364a696c26d15841d13d8;hp=19517afaa05ea7c66eecdf7a1431cc6fe48e04f3;hpb=b4593ceddfe4e149497b2ce8bbf2717ac3721337;p=mygpt.git diff --git a/picoclvr.py b/picoclvr.py index 19517af..3ecbf3a 100755 --- a/picoclvr.py +++ b/picoclvr.py @@ -95,9 +95,9 @@ def all_properties(height, width, nb_squares, square_i, square_j, square_c): def generate(nb, height, width, max_nb_squares = 5, max_nb_properties = 10, - many_colors = False): + nb_colors = 5): - nb_colors = len(color_tokens) - 1 if many_colors else max_nb_squares + assert nb_colors >= max_nb_squares and nb_colors <= len(color_tokens) - 1 descr = [ ] @@ -163,6 +163,7 @@ def descr2properties(descr, height, width): seen = {} if len(d) != height * width: return [] + for k, x in enumerate(d): if x != color_names[0]: if x in color_tokens: @@ -171,9 +172,15 @@ def descr2properties(descr, height, width): return [] seen[x] = (color_id[x], k // width, k % width) - square_c = torch.tensor( [ x[0] for x in seen.values() ] ) - square_i = torch.tensor( [ x[1] for x in seen.values() ] ) - square_j = torch.tensor( [ x[2] for x in seen.values() ] ) + square_infos = tuple(zip(*seen.values())) + if square_infos: + square_c = torch.tensor(square_infos[0]) + square_i = torch.tensor(square_infos[1]) + square_j = torch.tensor(square_infos[2]) + else: + square_c = torch.tensor([]) + square_i = torch.tensor([]) + square_j = torch.tensor([]) s = all_properties(height, width, len(seen), square_i, square_j, square_c) @@ -181,9 +188,9 @@ def descr2properties(descr, height, width): ###################################################################### -def nb_missing_properties(descr, height, width): +def nb_properties(descr, height, width): if type(descr) == list: - return [ nb_missing_properties(d, height, width) for d in descr ] + return [ nb_properties(d, height, width) for d in descr ] d = descr.split('', 1) if len(d) == 0: return 0 @@ -202,7 +209,7 @@ if __name__ == '__main__': descr = generate(nb = 5) #print(descr2properties(descr)) - print(nb_missing_properties(descr)) + print(nb_properties(descr)) with open('picoclvr_example.txt', 'w') as f: for d in descr: