projects
/
mygptrnn.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[mygptrnn.git]
/
tasks.py
diff --git a/tasks.py b/tasks.py
index 58638ed..57c6801 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -58,7 +58,7 @@ def masked_inplace_autoregression(
class Task:
class Task:
- def batches(self, split="train"):
+ def batches(self, split="train", desc=None):
pass
def vocabulary_size(self):
pass
def vocabulary_size(self):
@@ -106,7 +106,7 @@ class SandBox(Task):
device
), self.test_ar_mask.to(device)
device
), self.test_ar_mask.to(device)
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
# A bit of paranoia never hurts
assert self.nb_codes <= max_nb_codes
# A bit of paranoia never hurts
assert self.nb_codes <= max_nb_codes
@@ -250,7 +250,13 @@ class PicoCLVR(Task):
# Make a list of strings from a tensor
def detensorize(self, x):
# Make a list of strings from a tensor
def detensorize(self, x):
- return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
+ def id2token(t):
+ try:
+ return self.id2token[t.item()]
+ except KeyError:
+ return "?"
+
+ return [" ".join([id2token(t) for t in r]) for r in x]
# trim all the tensors in the tuple z to remove as much token from
# left and right in the first tensor. If z is a tuple, all its
# trim all the tensors in the tuple z to remove as much token from
# left and right in the first tensor. If z is a tuple, all its
@@ -328,7 +334,7 @@ class PicoCLVR(Task):
self.train_input = self.tensorize(self.train_descr)
self.test_input = self.tensorize(self.test_descr)
self.train_input = self.tensorize(self.train_descr)
self.test_input = self.tensorize(self.test_descr)
- def batches(self, split="train"):
+ def batches(self, split="train", desc=None):
assert split in {"train", "test"}
input = self.train_input if split == "train" else self.test_input
for batch in tqdm.tqdm(
assert split in {"train", "test"}
input = self.train_input if split == "train" else self.test_input
for batch in tqdm.tqdm(
@@ -573,7 +579,7 @@ class Maze(Task):
)
self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
)
self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
@@ -750,7 +756,7 @@ class Snake(Task):
self.device,
)
self.device,
)
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
@@ -865,7 +871,7 @@ class Stack(Task):
counts = F.one_hot(counts).sum(0)
logger(f"test_pop_stack_counts {counts}")
counts = F.one_hot(counts).sum(0)
logger(f"test_pop_stack_counts {counts}")
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
@@ -888,7 +894,10 @@ class Stack(Task):
def compute_nb_correct(input):
result = input.clone()
stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
def compute_nb_correct(input):
result = input.clone()
stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
+
ar_mask = (result != input).long()
ar_mask = (result != input).long()
+ result *= 1 - ar_mask
+
masked_inplace_autoregression(
model,
self.batch_size,
masked_inplace_autoregression(
model,
self.batch_size,
@@ -923,10 +932,12 @@ class Stack(Task):
stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
ar_mask = (result != input).long()
stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
ar_mask = (result != input).long()
- # for n in range(result.size(0)):
- # logger(
- # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
- # )
+ for n in range(result.size(0)):
+ logger(
+ f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
+ )
+
+ result *= 1 - ar_mask
masked_inplace_autoregression(
model,
masked_inplace_autoregression(
model,
@@ -1067,7 +1078,7 @@ class RPL(Task):
s = " ".join(seq)
logger(f"example_seq {s}")
s = " ".join(seq)
logger(f"example_seq {s}")
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
@@ -1297,7 +1308,7 @@ class Expr(Task):
self.train_input = self.tensorize(train_sequences)
self.test_input = self.tensorize(test_sequences)
self.train_input = self.tensorize(train_sequences)
self.test_input = self.tensorize(test_sequences)
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
@@ -1448,7 +1459,13 @@ class Grid(Task):
# Make a list of strings from a tensor
def tensor2str(self, x):
# Make a list of strings from a tensor
def tensor2str(self, x):
- return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
+ def id2token(t):
+ try:
+ return self.id2token[t.item()]
+ except KeyError:
+ return "?"
+
+ return [" ".join([id2token(t) for t in r]) for r in x]
# trim all the tensors in the tuple z to remove as much token from
# left and right in the first tensor. If z is a tuple, all its
# trim all the tensors in the tuple z to remove as much token from
# left and right in the first tensor. If z is a tuple, all its
@@ -1473,6 +1490,8 @@ class Grid(Task):
nb_test_samples,
batch_size,
size,
nb_test_samples,
batch_size,
size,
+ nb_shapes,
+ nb_colors,
logger=None,
device=torch.device("cpu"),
):
logger=None,
device=torch.device("cpu"),
):
@@ -1480,7 +1499,9 @@ class Grid(Task):
self.device = device
self.batch_size = batch_size
self.device = device
self.batch_size = batch_size
- self.grid_factory = grid.GridFactory(size=size)
+ self.grid_factory = grid.GridFactory(
+ size=size, nb_shapes=nb_shapes, nb_colors=nb_colors
+ )
if logger is not None:
logger(
if logger is not None:
logger(
@@ -1515,11 +1536,13 @@ class Grid(Task):
self.train_input = self.str2tensor(self.train_descr)
self.test_input = self.str2tensor(self.test_descr)
self.train_input = self.str2tensor(self.train_descr)
self.test_input = self.str2tensor(self.test_descr)
- def batches(self, split="train"):
+ def batches(self, split="train", desc=None):
assert split in {"train", "test"}
input = self.train_input if split == "train" else self.test_input
assert split in {"train", "test"}
input = self.train_input if split == "train" else self.test_input
+ if desc is None:
+ desc = f"epoch-{split}"
for batch in tqdm.tqdm(
for batch in tqdm.tqdm(
- input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
+ input.split(self.batch_size), dynamic_ncols=True, desc=desc
):
yield self.trim(batch)
):
yield self.trim(batch)
@@ -1616,13 +1639,15 @@ class QMLP(Task):
for e in self.test_ref_test_errors:
f.write(f"{e}\n")
for e in self.test_ref_test_errors:
f.write(f"{e}\n")
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
- def batches(self, split="train"):
+ def batches(self, split="train", desc=None):
assert split in {"train", "test"}
input = self.train_input if split == "train" else self.test_input
assert split in {"train", "test"}
input = self.train_input if split == "train" else self.test_input
+ if desc is None:
+ desc = f"epoch-{split}"
for batch in tqdm.tqdm(
for batch in tqdm.tqdm(
- input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
+ input.split(self.batch_size), dynamic_ncols=True, desc=desc
):
yield batch
):
yield batch