X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=picoclvr.py;h=059e352d09d16011bd7daee9e657405d676a5654;hb=62533ba50393866c15b322074cad836684dd69e7;hp=3ecbf3aa40d3e055c8b329e94512565c3765a4cb;hpb=3ae0c8f3767e4285ab548e4548576a6ddf6003bb;p=mygpt.git diff --git a/picoclvr.py b/picoclvr.py index 3ecbf3a..059e352 100755 --- a/picoclvr.py +++ b/picoclvr.py @@ -95,7 +95,8 @@ def all_properties(height, width, nb_squares, square_i, square_j, square_c): def generate(nb, height, width, max_nb_squares = 5, max_nb_properties = 10, - nb_colors = 5): + nb_colors = 5, + pruning_criterion = None): assert nb_colors >= max_nb_squares and nb_colors <= len(color_tokens) - 1 @@ -117,6 +118,9 @@ def generate(nb, height, width, s = all_properties(height, width, nb_squares, square_i, square_j, square_c) + if pruning_criterion is not None: + s = list(filter(pruning_criterion,s)) + # pick at most max_nb_properties at random nb_properties = torch.randint(max_nb_properties, (1,)) + 1 @@ -206,23 +210,26 @@ def nb_properties(descr, height, width): ###################################################################### if __name__ == '__main__': - descr = generate(nb = 5) + descr = generate( + nb = 5, height = 12, width = 16, + pruning_criterion = lambda s: not ('green' in s and ('right' in s or 'left' in s)) + ) - #print(descr2properties(descr)) - print(nb_properties(descr)) + print(descr2properties(descr, height = 12, width = 16)) + print(nb_properties(descr, height = 12, width = 16)) with open('picoclvr_example.txt', 'w') as f: for d in descr: f.write(f'{d}\n\n') - img = descr2img(descr) + img = descr2img(descr, height = 12, width = 16) torchvision.utils.save_image(img / 255., 'picoclvr_example.png', nrow = 16, pad_value = 0.8) import time start_time = time.perf_counter() - descr = generate(nb = 1000) + descr = generate(nb = 1000, height = 12, width = 16) end_time = time.perf_counter() print(f'{len(descr) / (end_time - start_time):.02f} samples per second')