projects
/
picoclvr.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[picoclvr.git]
/
tasks.py
diff --git
a/tasks.py
b/tasks.py
index
c7348d5
..
0ab1823
100755
(executable)
--- a/
tasks.py
+++ b/
tasks.py
@@
-1429,7
+1429,7
@@
class Grid(Task):
def tensorize(self, descr):
token_descr = [s.strip().split(" ") for s in descr]
l = max([len(s) for s in token_descr])
def tensorize(self, descr):
token_descr = [s.strip().split(" ") for s in descr]
l = max([len(s) for s in token_descr])
- token_descr = [s + ["
<nul>
"] * (l - len(s)) for s in token_descr]
+ token_descr = [s + ["
#
"] * (l - len(s)) for s in token_descr]
id_descr = [[self.token2id[u] for u in s] for s in token_descr]
return torch.tensor(id_descr, device=self.device)
id_descr = [[self.token2id[u] for u in s] for s in token_descr]
return torch.tensor(id_descr, device=self.device)
@@
-1440,7
+1440,7
@@
class Grid(Task):
# trim all the tensors in the tuple z to remove as much token from
# left and right in the first tensor. If z is a tuple, all its
# elements are trimed according to the triming for the first
# trim all the tensors in the tuple z to remove as much token from
# left and right in the first tensor. If z is a tuple, all its
# elements are trimed according to the triming for the first
- def trim(self, z, token="
<nul>
"):
+ def trim(self, z, token="
#
"):
n = self.token2id[token]
if type(z) == tuple:
x = z[0]
n = self.token2id[token]
if type(z) == tuple:
x = z[0]
@@
-1483,7
+1483,7
@@
class Grid(Task):
)
# Build the tokenizer
)
# Build the tokenizer
- tokens =
{}
+ tokens =
set()
for d in [self.train_descr, self.test_descr]:
for s in d:
for t in s.strip().split(" "):
for d in [self.train_descr, self.test_descr]:
for s in d:
for t in s.strip().split(" "):
@@
-1492,10
+1492,10
@@
class Grid(Task):
# the same descr
tokens = list(tokens)
tokens.sort()
# the same descr
tokens = list(tokens)
tokens.sort()
- tokens = ["
<nul>
"] + tokens
+ tokens = ["
#
"] + tokens
self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
- self.t_nul = self.token2id["
<nul>
"]
+ self.t_nul = self.token2id["
#
"]
self.t_true = self.token2id["<true>"]
self.t_false = self.token2id["<false>"]
self.t_true = self.token2id["<true>"]
self.t_false = self.token2id["<false>"]