X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=picoclvr.py;h=19517afaa05ea7c66eecdf7a1431cc6fe48e04f3;hb=b4593ceddfe4e149497b2ce8bbf2717ac3721337;hp=437439ef0f21a9897e1bec4b1fe369f88ddfd22b;hpb=5bc2d741ea7aac83005f099665b47f8a090931cb;p=mygpt.git diff --git a/picoclvr.py b/picoclvr.py index 437439e..19517af 100755 --- a/picoclvr.py +++ b/picoclvr.py @@ -93,7 +93,7 @@ def all_properties(height, width, nb_squares, square_i, square_j, square_c): ###################################################################### -def generate(nb, height = 6, width = 8, +def generate(nb, height, width, max_nb_squares = 5, max_nb_properties = 10, many_colors = False): @@ -129,7 +129,7 @@ def generate(nb, height = 6, width = 8, ###################################################################### -def descr2img(descr, height = 6, width = 8): +def descr2img(descr, height, width): if type(descr) == list: return torch.cat([ descr2img(d, height, width) for d in descr ], 0) @@ -152,7 +152,7 @@ def descr2img(descr, height = 6, width = 8): ###################################################################### -def descr2properties(descr, height = 6, width = 8): +def descr2properties(descr, height, width): if type(descr) == list: return [ descr2properties(d, height, width) for d in descr ] @@ -181,9 +181,28 @@ def descr2properties(descr, height = 6, width = 8): ###################################################################### +def nb_missing_properties(descr, height, width): + if type(descr) == list: + return [ nb_missing_properties(d, height, width) for d in descr ] + + d = descr.split('', 1) + if len(d) == 0: return 0 + d = d[0].strip().split('') + d = [ x.strip() for x in d ] + + requested_properties = set(d) + all_properties = set(descr2properties(descr, height, width)) + missing_properties = requested_properties - all_properties + + return (len(requested_properties), len(all_properties), len(missing_properties)) + +###################################################################### + if __name__ == '__main__': descr = generate(nb = 5) - print(descr2properties(descr)) + + #print(descr2properties(descr)) + print(nb_missing_properties(descr)) with open('picoclvr_example.txt', 'w') as f: for d in descr: @@ -196,7 +215,7 @@ if __name__ == '__main__': import time start_time = time.perf_counter() - descr = generate(10000) + descr = generate(nb = 1000) end_time = time.perf_counter() print(f'{len(descr) / (end_time - start_time):.02f} samples per second')