From: Francois Fleuret Date: Fri, 10 Jun 2022 09:18:26 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=9425d322d40205b07000fd1bd23ef0696085c30b;p=mygpt.git Update. --- diff --git a/main.py b/main.py index a6940f1..a31284e 100755 --- a/main.py +++ b/main.py @@ -136,10 +136,10 @@ class TaskPicoCLVR(Task): def batches(self, split = 'train'): assert split in { 'train', 'test' } if split == 'train': - for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = 'epoch'): + for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'): yield batch else: - for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = 'epoch'): + for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'): yield batch def vocabulary_size(self): @@ -237,7 +237,7 @@ class TaskWiki103(Task): if args.data_size > 0: data_iter = itertools.islice(data_iter, args.data_size) - return self.yield_batches(tqdm.tqdm(data_iter, desc = 'epoch')) + return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}')) def vocabulary_size(self): return len(self.vocab) @@ -296,7 +296,7 @@ class TaskMNIST(Task): data_input = data_set.data.view(-1, 28 * 28).long() if args.data_size >= 0: data_input = data_input[:args.data_size] - for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = 'epoch'): + for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'): yield batch def vocabulary_size(self):