def vocabulary_size(self):
return len(self.token2id)
- def produce_results(self, n_epoch, model):
+ def test_model(self, n_epoch, model, primers_descr, nb_per_primer=1, generate_images=False):
nb_tokens_to_generate = self.height * self.width + 3
result_descr = [ ]
- nb_per_primer = 8
- for primer_descr in [
- 'red above green <sep> green top <sep> blue right of red <img>',
- 'there is red <sep> there is yellow <sep> there is blue <img>',
- 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
- 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
- ]:
+ for primer_descr in primers_descr:
results = autoregression(
model,
log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')
- img = [
- picoclvr.descr2img(d, height = self.height, width = self.width)
- for d in result_descr
+ np=torch.tensor(np)
+ count=torch.empty(np[:,0].max()+1,np[:,2].max()+1,dtype=torch.int64)
+ for i in range(count.size(0)):
+ for j in range(count.size(1)):
+ count[i,j]=((np[:,0]==i).long()*(np[:,2]==j).long()).sum()
+
+ if generate_images:
+ img = [
+ picoclvr.descr2img(d, height = self.height, width = self.width)
+ for d in result_descr
+ ]
+
+ img = torch.cat(img, 0)
+ image_name = f'result_picoclvr_{n_epoch:04d}.png'
+ torchvision.utils.save_image(
+ img / 255.,
+ image_name, nrow = nb_per_primer, pad_value = 0.8
+ )
+ log_string(f'wrote {image_name}')
+
+ return count
+
+ def produce_results(self, n_epoch, model):
+ primers_descr = [
+ 'red above green <sep> green top <sep> blue right of red <img>',
+ 'there is red <sep> there is yellow <sep> there is blue <img>',
+ 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
+ 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
]
- img = torch.cat(img, 0)
- image_name = f'result_picoclvr_{n_epoch:04d}.png'
- torchvision.utils.save_image(
- img / 255.,
- image_name, nrow = nb_per_primer, pad_value = 0.8
+ self.test_model(
+ n_epoch, model,
+ primers_descr,
+ nb_per_primer=8, generate_images=True
)
- log_string(f'wrote {image_name}')
+
+ # FAR TOO SLOW!!!
+
+ # test_primers_descr=[ s.split('<img>')[0] for s in self.test_descr ]
+
+ # count=self.test_model(
+ # n_epoch, model,
+ # test_primers_descr,
+ # nb_per_primer=1, generate_images=False
+ # )
+
+ # with open(f'perf_{n_epoch:04d}.txt', 'w') as f:
+ # for i in range(count.size(0)):
+ # for j in range(count.size(1)):
+ # f.write(f'{count[i,j]}')
+ # f.write(" " if j<count.size(1)-1 else "\n")
######################################################################
def generate(nb, height, width,
max_nb_squares = 5, max_nb_properties = 10,
- nb_colors = 5):
+ nb_colors = 5,
+ pruning_criterion = None):
assert nb_colors >= max_nb_squares and nb_colors <= len(color_tokens) - 1
s = all_properties(height, width, nb_squares, square_i, square_j, square_c)
+ if pruning_criterion is not None:
+ s = list(filter(pruning_criterion,s))
+
# pick at most max_nb_properties at random
nb_properties = torch.randint(max_nb_properties, (1,)) + 1
######################################################################
if __name__ == '__main__':
- descr = generate(nb = 5)
+ descr = generate(
+ nb = 5, height = 12, width = 16,
+ pruning_criterion = lambda s: not ('green' in s and ('right' in s or 'left' in s))
+ )
- #print(descr2properties(descr))
- print(nb_properties(descr))
+ print(descr2properties(descr, height = 12, width = 16))
+ print(nb_properties(descr, height = 12, width = 16))
with open('picoclvr_example.txt', 'w') as f:
for d in descr:
f.write(f'{d}\n\n')
- img = descr2img(descr)
+ img = descr2img(descr, height = 12, width = 16)
torchvision.utils.save_image(img / 255.,
'picoclvr_example.png', nrow = 16, pad_value = 0.8)
import time
start_time = time.perf_counter()
- descr = generate(nb = 1000)
+ descr = generate(nb = 1000, height = 12, width = 16)
end_time = time.perf_counter()
print(f'{len(descr) / (end_time - start_time):.02f} samples per second')