projects
/
mygpt.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
1d49fa4
)
Update.
author
Francois Fleuret
<francois@fleuret.org>
Fri, 10 Jun 2022 09:18:26 +0000
(11:18 +0200)
committer
Francois Fleuret
<francois@fleuret.org>
Fri, 10 Jun 2022 09:18:26 +0000
(11:18 +0200)
main.py
patch
|
blob
|
history
diff --git
a/main.py
b/main.py
index
a6940f1
..
a31284e
100755
(executable)
--- a/
main.py
+++ b/
main.py
@@
-136,10
+136,10
@@
class TaskPicoCLVR(Task):
def batches(self, split = 'train'):
assert split in { 'train', 'test' }
if split == 'train':
def batches(self, split = 'train'):
assert split in { 'train', 'test' }
if split == 'train':
- for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc =
'epoch
'):
+ for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc =
f'epoch-{split}
'):
yield batch
else:
yield batch
else:
- for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc =
'epoch
'):
+ for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc =
f'epoch-{split}
'):
yield batch
def vocabulary_size(self):
yield batch
def vocabulary_size(self):
@@
-237,7
+237,7
@@
class TaskWiki103(Task):
if args.data_size > 0:
data_iter = itertools.islice(data_iter, args.data_size)
if args.data_size > 0:
data_iter = itertools.islice(data_iter, args.data_size)
- return self.yield_batches(tqdm.tqdm(data_iter, desc =
'epoch
'))
+ return self.yield_batches(tqdm.tqdm(data_iter, desc =
f'epoch-{split}
'))
def vocabulary_size(self):
return len(self.vocab)
def vocabulary_size(self):
return len(self.vocab)
@@
-296,7
+296,7
@@
class TaskMNIST(Task):
data_input = data_set.data.view(-1, 28 * 28).long()
if args.data_size >= 0:
data_input = data_input[:args.data_size]
data_input = data_set.data.view(-1, 28 * 28).long()
if args.data_size >= 0:
data_input = data_input[:args.data_size]
- for batch in tqdm.tqdm(data_input.split(self.batch_size), desc =
'epoch
'):
+ for batch in tqdm.tqdm(data_input.split(self.batch_size), desc =
f'epoch-{split}
'):
yield batch
def vocabulary_size(self):
yield batch
def vocabulary_size(self):