projects
/
mygpt.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Finalized PicoCLVR with "many colors".
[mygpt.git]
/
main.py
diff --git
a/main.py
b/main.py
index
a31284e
..
3bf7587
100755
(executable)
--- a/
main.py
+++ b/
main.py
@@
-111,12
+111,20
@@
import picoclvr
class TaskPicoCLVR(Task):
class TaskPicoCLVR(Task):
- def __init__(self, batch_size, height = 6, width = 8, device = torch.device('cpu')):
+ def __init__(self, batch_size,
+ height = 6, width = 8, many_colors = False,
+ device = torch.device('cpu')):
+
self.batch_size = batch_size
self.device = device
nb = args.data_size if args.data_size > 0 else 250000
self.batch_size = batch_size
self.device = device
nb = args.data_size if args.data_size > 0 else 250000
- descr = picoclvr.generate(nb, height = height, width = width)
+ descr = picoclvr.generate(
+ nb,
+ height = height, width = width,
+ many_colors = many_colors
+ )
+
descr = [ s.strip().split(' ') for s in descr ]
l = max([ len(s) for s in descr ])
descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
descr = [ s.strip().split(' ') for s in descr ]
l = max([ len(s) for s in descr ])
descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]