projects
/
picoclvr.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
c3581ba
)
Oups
author
François Fleuret
<francois@fleuret.org>
Sat, 20 Apr 2024 06:59:20 +0000
(08:59 +0200)
committer
François Fleuret
<francois@fleuret.org>
Sat, 20 Apr 2024 06:59:20 +0000
(08:59 +0200)
tasks.py
patch
|
blob
|
history
diff --git
a/tasks.py
b/tasks.py
index
3ef64d7
..
c0ad5ff
100755
(executable)
--- a/
tasks.py
+++ b/
tasks.py
@@
-63,7
+63,7
@@
def masked_inplace_autoregression(
class Task:
class Task:
- def batches(self, split="train"):
+ def batches(self, split="train"
, nb_to_use=-1, desc=None
):
pass
def vocabulary_size(self):
pass
def vocabulary_size(self):
@@
-489,7
+489,7
@@
class PicoCLVR(Task):
self.train_input = self.tensorize(self.train_descr)
self.test_input = self.tensorize(self.test_descr)
self.train_input = self.tensorize(self.train_descr)
self.test_input = self.tensorize(self.test_descr)
- def batches(self, split="train"):
+ def batches(self, split="train"
, nb_to_use=-1, desc=None
):
assert split in {"train", "test"}
input = self.train_input if split == "train" else self.test_input
for batch in tqdm.tqdm(
assert split in {"train", "test"}
input = self.train_input if split == "train" else self.test_input
for batch in tqdm.tqdm(
@@
-1685,7
+1685,7
@@
class Grid(Task):
self.t_nul = self.token2id["#"]
self.t_true = self.token2id["true"]
self.t_false = self.token2id["false"]
self.t_nul = self.token2id["#"]
self.t_true = self.token2id["true"]
self.t_false = self.token2id["false"]
- self.t_pipe = self.token2id["|"]
+
#
self.t_pipe = self.token2id["|"]
# Tokenize the train and test sets
self.train_input = self.str2tensor(self.train_descr)
# Tokenize the train and test sets
self.train_input = self.str2tensor(self.train_descr)
@@
-1694,7
+1694,7
@@
class Grid(Task):
None if len(self.play_descr) == 0 else self.str2tensor(self.play_descr)
)
None if len(self.play_descr) == 0 else self.str2tensor(self.play_descr)
)
- def batches(self, split="train"):
+ def batches(self, split="train"
, nb_to_use=-1, desc=None
):
assert split in {"train", "test"}
input = self.train_input if split == "train" else self.test_input
for batch in tqdm.tqdm(
assert split in {"train", "test"}
input = self.train_input if split == "train" else self.test_input
for batch in tqdm.tqdm(
@@
-1823,7
+1823,7
@@
class QMLP(Task):
self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
- def batches(self, split="train"):
+ def batches(self, split="train"
, nb_to_use=-1, desc=None
):
assert split in {"train", "test"}
input = self.train_input if split == "train" else self.test_input
for batch in tqdm.tqdm(
assert split in {"train", "test"}
input = self.train_input if split == "train" else self.test_input
for batch in tqdm.tqdm(