From: François Fleuret Date: Tue, 18 Jul 2023 20:03:24 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=1be1638f9906a1071dc82ebc6f35f8fc0eb91a3d;p=picoclvr.git Update. --- diff --git a/tasks.py b/tasks.py index 332d6c5..5ac78cb 100755 --- a/tasks.py +++ b/tasks.py @@ -101,7 +101,7 @@ class ProblemLevel1(Problem): b = a + 1 + self.len_prompt sequences = torch.empty(nb, b + 1 + self.len_result, dtype=torch.int64) nb_operators = torch.randint(self.operators.size(0), (nb,)) - sequences[:, :a] = (nb_operators[:, None] / 10 ** torch.arange(a)) % 10 + sequences[:, :a] = (nb_operators[:, None] / 10 ** torch.arange(a-1,-1,-1)) % 10 sequences[:, a] = 10 sequences[:, a + 1 : b] = torch.randint(10, (nb, b - a - 1)) sequences[:, b] = 11 @@ -115,7 +115,7 @@ class ProblemLevel1(Problem): return sequences, ar_mask def seq2str(self, seq): - return "".join(self.id2char[x.item()] for x in seq) + return "".join("0123456789|>"[x.item()] for x in seq) ####################