projects
/
picoclvr.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[picoclvr.git]
/
problems.py
diff --git
a/problems.py
b/problems.py
index
aa3acf0
..
2e0ca36
100755
(executable)
--- a/
problems.py
+++ b/
problems.py
@@
-47,14
+47,22
@@
class ProblemTwoTargets(Problem):
a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :])
a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :])
sequences = torch.cat(
a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :])
a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :])
sequences = torch.cat(
- (s, torch.full((nb, 1), 12), a1, torch.full((nb, 1), 12), a2), 1
+ (
+ s,
+ torch.full((nb, 1), 12),
+ a1,
+ torch.full((nb, 1), 12),
+ a2,
+ torch.full((nb, 1), 12),
+ ),
+ 1,
)
ar_mask = (sequences == 12).long()
ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
return sequences, ar_mask
def seq2str(self, seq):
)
ar_mask = (sequences == 12).long()
ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
return sequences, ar_mask
def seq2str(self, seq):
- return "".join("0123456789
+-
|"[x.item()] for x in seq)
+ return "".join("0123456789
-+
|"[x.item()] for x in seq)
####################
####################