######################################################################
def autoregression(
- model,
+ model, batch_size,
nb_samples, nb_tokens_to_generate, starting_input = None,
device = torch.device('cpu')
):
first = starting_input.size(1)
results = torch.cat((starting_input, results), 1)
- for input in results.split(args.batch_size):
+ for input in results.split(batch_size):
for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'):
output = model(input)
logits = output[:, s]
descr = [ s.strip().split(' ') for s in descr ]
l = max([ len(s) for s in descr ])
+ #descr = [ [ '<unk>' ] * (l - len(s)) + s for s in descr ]
descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
return descr
self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
+ # Tokenize the train and test sets
t = [ [ self.token2id[u] for u in s ] for s in self.train_descr ]
self.train_input = torch.tensor(t, device = self.device)
t = [ [ self.token2id[u] for u in s ] for s in self.test_descr ]
)
log_string(f'wrote {image_name}')
- nb_missing = sum( [
- x[2] for x in picoclvr.nb_missing_properties(
- descr,
- height = self.height, width = self.width
- )
- ] )
+ np = picoclvr.nb_properties(
+ descr,
+ height = self.height, width = self.width
+ )
+
+ nb_requested_properties, _, nb_missing_properties = zip(*np)
- log_string(f'nb_missing {nb_missing / len(descr):.02f}')
+ log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(descr):.02f}')
######################################################################
return 256
def produce_results(self, n_epoch, model, nb_samples = 64):
- results = autoregression(model, nb_samples, 28 * 28, device = self.device)
+ results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
image_name = f'result_mnist_{n_epoch:04d}.png'
torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
image_name, nrow = 16, pad_value = 0.8)