projects
/
mygptrnn.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[mygptrnn.git]
/
tasks.py
diff --git a/tasks.py b/tasks.py
index 58638ed..727b196 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -58,7 +58,7 @@ def masked_inplace_autoregression(
 class Task:
-    def batches(self, split="train"):
+    def batches(self, split="train", desc=None):
pass
def vocabulary_size(self):
pass
def vocabulary_size(self):
@@ -328,7 +328,7 @@ class PicoCLVR(Task):
         self.train_input = self.tensorize(self.train_descr)
         self.test_input = self.tensorize(self.test_descr)

-    def batches(self, split="train"):
+    def batches(self, split="train", desc=None):
assert split in {"train", "test"}
input = self.train_input if split == "train" else self.test_input
for batch in tqdm.tqdm(
assert split in {"train", "test"}
input = self.train_input if split == "train" else self.test_input
for batch in tqdm.tqdm(
@@ -1473,6 +1473,8 @@ class Grid(Task):
         nb_test_samples,
         batch_size,
         size,
+        nb_shapes,
+        nb_colors,
logger=None,
device=torch.device("cpu"),
):
logger=None,
device=torch.device("cpu"),
):
@@ -1480,7 +1482,9 @@ class Grid(Task):
         self.device = device
         self.batch_size = batch_size
-        self.grid_factory = grid.GridFactory(size=size)
+        self.grid_factory = grid.GridFactory(
+            size=size, nb_shapes=nb_shapes, nb_colors=nb_colors
+        )
if logger is not None:
logger(
if logger is not None:
logger(
@@ -1515,11 +1519,13 @@ class Grid(Task):
         self.train_input = self.str2tensor(self.train_descr)
         self.test_input = self.str2tensor(self.test_descr)

-    def batches(self, split="train"):
+    def batches(self, split="train", desc=None):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
+        if desc is None:
+            desc = f"epoch-{split}"
         for batch in tqdm.tqdm(
-            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
+            input.split(self.batch_size), dynamic_ncols=True, desc=desc
         ):
             yield self.trim(batch)
@@ -1618,11 +1624,13 @@ class QMLP(Task):

         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

-    def batches(self, split="train"):
+    def batches(self, split="train", desc=None):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
+        if desc is None:
+            desc = f"epoch-{split}"
         for batch in tqdm.tqdm(
-            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
+            input.split(self.batch_size), dynamic_ncols=True, desc=desc
         ):
             yield batch