X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=5f3e8cffe5bce08167141cb79e116e1f5a9e336a;hb=1ad9ea3cca4489b07bad8521966382f66a493eea;hp=77b1b226f95c4d9a11dd6fe5990d575ae0795399;hpb=b4c255babaae72d6a03b4c8e8e7e25f6ab0a19a0;p=mygpt.git

diff --git a/main.py b/main.py
index 77b1b22..5f3e8cf 100755
--- a/main.py
+++ b/main.py
@@ -18,7 +18,6 @@ import mygpt
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 ######################################################################
-
 parser = argparse.ArgumentParser(description = 'My own GPT.')
 
 parser.add_argument('--log_filename',
@@ -148,7 +147,7 @@ class Task:
     def vocabulary_size(self):
         pass
 
-    def produce_results(self, n_epoch, model, nb_tokens = 50):
+    def produce_results(self, n_epoch, model):
         pass
 
 ######################################################################
@@ -157,7 +156,7 @@ import picoclvr
 
 class TaskPicoCLVR(Task):
 
-    def descr2tensor(self, descr):
+    def tensorize(self, descr):
         t = [ [ self.token2id[u] for u in s ] for s in descr ]
         return torch.tensor(t, device = self.device)
 
@@ -174,8 +173,8 @@ class TaskPicoCLVR(Task):
 
         descr = [ s.strip().split(' ') for s in descr ]
         l = max([ len(s) for s in descr ])
-        #descr = [ [ '<unk>' ] * (l - len(s)) + s for s in descr ]
-        descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
+        #descr = [ [ '<nul>' ] * (l - len(s)) + s for s in descr ]
+        descr = [ s + [ '<nul>' ] * (l - len(s)) for s in descr ]
 
         return descr
 
@@ -197,32 +196,20 @@ class TaskPicoCLVR(Task):
         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
 
         # Tokenize the train and test sets
-        self.train_input = descr2tensor(self.train_descr)
-        self.test_input = descr2tensor(self.test_descr)
+        self.train_input = self.tensorize(self.train_descr)
+        self.test_input = self.tensorize(self.test_descr)
 
     def batches(self, split = 'train'):
         assert split in { 'train', 'test' }
-        if split == 'train':
-            for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
-                yield batch
-        else:
-            for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
-                yield batch
+        input = self.train_input if split == 'train' else self.test_input
+        for batch in tqdm.tqdm(input.split(self.batch_size), desc = f'epoch-{split}'):
+            yield batch
 
     def vocabulary_size(self):
         return len(self.token2id)
 
-    def generate(self, primer_descr, model, nb_tokens):
-        results = autoregression(
-            model, self.batch_size,
-            1, nb_tokens, primer = descr2tensor(primer_descr),
-            device = self.device
-        )
-        return ' '.join([ self.id2token[t.item()] for t in results.flatten() ])
-
-    def produce_results(self, n_epoch, model, nb_tokens = None):
-        if nb_tokens is None:
-            nb_tokens = self.height * self.width + 3
+    def produce_results(self, n_epoch, model):
+        nb_tokens = self.height * self.width + 3
         result_descr = [ ]
         nb_per_primer = 8
 
@@ -234,10 +221,20 @@ class TaskPicoCLVR(Task):
         ]:
 
             for k in range(nb_per_primer):
-                result_descr.append(self.generate(primer_descr, model, nb_tokens))
+                results = autoregression(
+                    model, self.batch_size,
+                    nb_samples = 1, nb_tokens = nb_tokens,
+                    primer = self.tensorize(primer_descr),
+                    device = self.device
+                )
+                r = ' '.join([ self.id2token[t.item()] for t in results.flatten() ])
+                result_descr.append(r)
+
+        img = [
+            picoclvr.descr2img(d, height = self.height, width = self.width)
+            for d in result_descr
+        ]
 
-        img = [ picoclvr.descr2img(d, height = self.height, width = self.width)
-                for d in result_descr ]
         img = torch.cat(img, 0)
         image_name = f'result_picoclvr_{n_epoch:04d}.png'
         torchvision.utils.save_image(
@@ -281,7 +278,7 @@ class TaskWiki103(Task):
 
         self.vocab = torchtext.vocab.build_vocab_from_iterator(
             yield_tokens(),
-            specials = [ '<unk>', '<nul>' ],
+            specials = [ '<unk>', '<non>' ],
             min_freq = self.min_freq
         )
 
@@ -289,7 +286,7 @@ class TaskWiki103(Task):
 
     def tensorize(self, s):
         a = max(len(x) for x in s)
-        return torch.tensor([ self.vocab(x + [ '<nul>' ] * (a - len(x))) for x in s ])
+        return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
 
     def yield_batches(self, ds):
         s = [ ]
@@ -316,7 +313,8 @@ class TaskWiki103(Task):
     def vocabulary_size(self):
         return len(self.vocab)
 
-    def produce_results(self, n_epoch, model, nb_tokens = 50):
+    def produce_results(self, n_epoch, model):
+        nb_tokens = 50
         file_name = f'result_wiki103_{n_epoch:04d}.txt'
 
         with open(file_name, 'w') as outfile:
@@ -346,7 +344,7 @@ class TaskWiki103(Task):
                 else:
                     t_next = logits.argmax()
                 t_generated.append(self.vocab.lookup_token(t_next))
-                if t_generated[-1] == '<nul>': break
+                if t_generated[-1] == '<non>': break
 
             s = ' '.join(t_generated)
 
@@ -377,7 +375,8 @@ class TaskMNIST(Task):
     def vocabulary_size(self):
         return 256
 
-    def produce_results(self, n_epoch, model, nb_samples = 64):
+    def produce_results(self, n_epoch, model):
+        nb_samples = 64
         results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
         image_name = f'result_mnist_{n_epoch:04d}.png'
         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
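The change folds the old TaskPicoCLVR.generate helper into produce_results and moves the sampling budget (nb_tokens, nb_samples) into each task, so the caller only ever invokes produce_results(n_epoch, model). The token-by-token loop that autoregression performs is visible in the TaskWiki103 hunk. Below is a minimal, self-contained sketch of that pattern; UniformModel and sample are hypothetical stand-ins invented for illustration (not mygpt's autoregression), and only the sample-versus-argmax branch mirrors the args.synthesis_sampling logic shown in the diff.

import torch

class UniformModel(torch.nn.Module):
    # Hypothetical stand-in for the GPT: maps (N, T) token indices to
    # (N, T, V) logits; all-zero logits make every token equally likely.
    def __init__(self, vocabulary_size):
        super().__init__()
        self.vocabulary_size = vocabulary_size

    def forward(self, input):
        return torch.zeros(input.size(0), input.size(1), self.vocabulary_size)

def sample(model, primer, nb_tokens, synthesis_sampling = True):
    # Grow the sequence one token at a time, feeding the model its own output.
    input = primer.clone()
    for _ in range(nb_tokens):
        logits = model(input)[:, -1]             # logits of the next token
        if synthesis_sampling:
            dist = torch.distributions.categorical.Categorical(logits = logits)
            t_next = dist.sample()               # stochastic synthesis
        else:
            t_next = logits.argmax(dim = -1)     # greedy decoding
        input = torch.cat([ input, t_next[:, None] ], dim = 1)
    return input

model = UniformModel(vocabulary_size = 10)
primer = torch.randint(10, (2, 3))
print(sample(model, primer, nb_tokens = 5))      # (2, 8) tensor of token indices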