From 915db2eb89076dca35cc89df3ad895ddf346475f Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Wed, 27 Jul 2022 16:15:39 +0200 Subject: [PATCH] Cleaning up more. --- main.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/main.py b/main.py index 77b1b22..339d185 100755 --- a/main.py +++ b/main.py @@ -18,7 +18,6 @@ import mygpt device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') ###################################################################### - parser = argparse.ArgumentParser(description = 'My own GPT.') parser.add_argument('--log_filename', @@ -148,7 +147,7 @@ class Task: def vocabulary_size(self): pass - def produce_results(self, n_epoch, model, nb_tokens = 50): + def produce_results(self, n_epoch, model): pass ###################################################################### @@ -202,12 +201,9 @@ class TaskPicoCLVR(Task): def batches(self, split = 'train'): assert split in { 'train', 'test' } - if split == 'train': - for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'): - yield batch - else: - for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'): - yield batch + input = self.train_input if split == 'train' else self.test_input + for batch in tqdm.tqdm(input.split(self.batch_size), desc = f'epoch-{split}'): + yield batch def vocabulary_size(self): return len(self.token2id) @@ -215,14 +211,13 @@ class TaskPicoCLVR(Task): def generate(self, primer_descr, model, nb_tokens): results = autoregression( model, self.batch_size, - 1, nb_tokens, primer = descr2tensor(primer_descr), + nb_samples = 1, nb_tokens = nb_tokens, primer = descr2tensor(primer_descr), device = self.device ) return ' '.join([ self.id2token[t.item()] for t in results.flatten() ]) - def produce_results(self, n_epoch, model, nb_tokens = None): - if nb_tokens is None: - nb_tokens = self.height * self.width + 3 + def produce_results(self, n_epoch, model): + nb_tokens = self.height * self.width + 3 result_descr = [ ] nb_per_primer = 8 @@ -316,7 +311,8 @@ class TaskWiki103(Task): def vocabulary_size(self): return len(self.vocab) - def produce_results(self, n_epoch, model, nb_tokens = 50): + def produce_results(self, n_epoch, model): + nb_tokens = 50 file_name = f'result_wiki103_{n_epoch:04d}.txt' with open(file_name, 'w') as outfile: @@ -377,7 +373,8 @@ class TaskMNIST(Task): def vocabulary_size(self): return 256 - def produce_results(self, n_epoch, model, nb_samples = 64): + def produce_results(self, n_epoch, model): + nb_samples = 64 results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device) image_name = f'result_mnist_{n_epoch:04d}.png' torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255., -- 2.39.5