projects
/
mygptrnn.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[mygptrnn.git]
/
tasks.py
diff --git
a/tasks.py
b/tasks.py
index
218ff36
..
57c6801
100755
(executable)
--- a/
tasks.py
+++ b/
tasks.py
@@
-106,7
+106,7
@@
class SandBox(Task):
device
), self.test_ar_mask.to(device)
device
), self.test_ar_mask.to(device)
- self.nb_codes =
max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes =
(max(self.train_input.max(), self.test_input.max()) + 1).item()
# A bit of paranoia never hurts
assert self.nb_codes <= max_nb_codes
# A bit of paranoia never hurts
assert self.nb_codes <= max_nb_codes
@@
-579,7
+579,7
@@
class Maze(Task):
)
self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
)
self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
- self.nb_codes =
max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes =
(max(self.train_input.max(), self.test_input.max()) + 1).item()
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
@@
-756,7
+756,7
@@
class Snake(Task):
self.device,
)
self.device,
)
- self.nb_codes =
max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes =
(max(self.train_input.max(), self.test_input.max()) + 1).item()
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
@@
-871,7
+871,7
@@
class Stack(Task):
counts = F.one_hot(counts).sum(0)
logger(f"test_pop_stack_counts {counts}")
counts = F.one_hot(counts).sum(0)
logger(f"test_pop_stack_counts {counts}")
- self.nb_codes =
max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes =
(max(self.train_input.max(), self.test_input.max()) + 1).item()
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
@@
-1078,7
+1078,7
@@
class RPL(Task):
s = " ".join(seq)
logger(f"example_seq {s}")
s = " ".join(seq)
logger(f"example_seq {s}")
- self.nb_codes =
max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes =
(max(self.train_input.max(), self.test_input.max()) + 1).item()
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
@@
-1308,7
+1308,7
@@
class Expr(Task):
self.train_input = self.tensorize(train_sequences)
self.test_input = self.tensorize(test_sequences)
self.train_input = self.tensorize(train_sequences)
self.test_input = self.tensorize(test_sequences)
- self.nb_codes =
max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes =
(max(self.train_input.max(), self.test_input.max()) + 1).item()
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
@@
-1639,7
+1639,7
@@
class QMLP(Task):
for e in self.test_ref_test_errors:
f.write(f"{e}\n")
for e in self.test_ref_test_errors:
f.write(f"{e}\n")
- self.nb_codes =
max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes =
(max(self.train_input.max(), self.test_input.max()) + 1).item()
def batches(self, split="train", desc=None):
assert split in {"train", "test"}
def batches(self, split="train", desc=None):
assert split in {"train", "test"}