From b5efc396f45c23b7de0fe11f618731ac2b900d99 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Fri, 15 Jul 2022 18:10:14 +0200 Subject: [PATCH] Update. --- picoclvr.py | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/picoclvr.py b/picoclvr.py index 26f53ab..774ae3b 100755 --- a/picoclvr.py +++ b/picoclvr.py @@ -93,7 +93,7 @@ def all_properties(height, width, nb_squares, square_i, square_j, square_c): ###################################################################### def generate(nb, height = 6, width = 8, - max_nb_squares = 5, max_nb_statements = 10, + max_nb_squares = 5, max_nb_properties = 10, many_colors = False): nb_colors = len(color_tokens) - 1 if many_colors else max_nb_squares @@ -112,14 +112,14 @@ def generate(nb, height = 6, width = 8, img = [ 0 ] * height * width for k in range(nb_squares): img[square_position[k]] = square_c[k] - # generates all the true relations + # generates all the true properties s = all_properties(height, width, nb_squares, square_i, square_j, square_c) - # pick at most max_nb_statements at random + # pick at most max_nb_properties at random - nb_statements = torch.randint(max_nb_statements, (1,)) + 1 - s = ' '.join([ s[k] for k in torch.randperm(len(s))[:nb_statements] ] ) + nb_properties = torch.randint(max_nb_properties, (1,)) + 1 + s = ' '.join([ s[k] for k in torch.randperm(len(s))[:nb_properties] ] ) s += ' ' + ' '.join([ f'{color_names[n]}' for n in img ]) descr += [ s ] @@ -130,23 +130,22 @@ def generate(nb, height = 6, width = 8, def descr2img(descr, height = 6, width = 8): + if type(descr) == list: + return torch.cat([ descr2img(d) for d in descr ], 0) + def token2color(t): try: return color_tokens[t] except KeyError: return [ 128, 128, 128 ] - def img_descr(x): - u = x.split('', 1) - return u[1] if len(u) > 1 else '' - - img = torch.full((len(descr), 3, height, width), 255) - d = [ img_descr(x) for x in descr ] - d = [ u.strip().split(' ')[:height * width] for u in d ] - d = [ u + [ '' ] * (height * width - len(u)) for u in d ] - d = [ [ token2color(t) for t in u ] for u in d ] - img = torch.tensor(d).permute(0, 2, 1) - img = img.reshape(img.size(0), 3, height, width) + d = descr.split('', 1) + d = d[-1] if len(d) > 1 else '' + d = d.strip().split(' ')[:height * width] + d = d + [ '' ] * (height * width - len(d)) + d = [ token2color(t) for t in d ] + img = torch.tensor(d).permute(1, 0) + img = img.reshape(1, 3, height, width) return img -- 2.39.5