X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=problems.py;h=2e0ca36a3803b629e17d78963a2096a3fc6347fc;hb=59600257e0eda86816a43676c5ffbe598d78bdb5;hp=aa3acf038559195f4e68d7fd7e1625ffb41e332d;hpb=687d5b2d9f465577665991b84faec7c789685271;p=picoclvr.git diff --git a/problems.py b/problems.py index aa3acf0..2e0ca36 100755 --- a/problems.py +++ b/problems.py @@ -47,14 +47,22 @@ class ProblemTwoTargets(Problem): a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :]) a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :]) sequences = torch.cat( - (s, torch.full((nb, 1), 12), a1, torch.full((nb, 1), 12), a2), 1 + ( + s, + torch.full((nb, 1), 12), + a1, + torch.full((nb, 1), 12), + a2, + torch.full((nb, 1), 12), + ), + 1, ) ar_mask = (sequences == 12).long() ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) return sequences, ar_mask def seq2str(self, seq): - return "".join("0123456789+-|"[x.item()] for x in seq) + return "".join("0123456789-+|"[x.item()] for x in seq) ####################