X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=problems.py;h=3cdd3747b75711d317a2bf11312148f69f02ceba;hb=9d4193312e06ed284b1368b7f4407f2b4f981c7a;hp=9e368c252d06ef9ec721799926e30b52c57773d2;hpb=408f2335af43590ee2d99c3286cbe3762c76887a;p=mygptrnn.git diff --git a/problems.py b/problems.py index 9e368c2..3cdd374 100755 --- a/problems.py +++ b/problems.py @@ -149,7 +149,13 @@ class ProblemMemory(Problem): return sequences, ar_mask def seq2str(self, seq): - return "".join(self.token_string[x.item()] for x in seq) + def decode(x): + if x < len(self.token_string): + return self.token_string[x] + else: + return "?" + + return "".join(decode(x.item()) for x in seq) class ProblemTwoTargets(Problem):