X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=73f61bf102dd46387bde2ecd99f0d7a74c2d7250;hb=a3211f96c7426a613b82a2de87d4dd70640e8f46;hp=5ac78cbbbbf698d8a6c6afc5087fa14b684968e6;hpb=1be1638f9906a1071dc82ebc6f35f8fc0eb91a3d;p=picoclvr.git diff --git a/tasks.py b/tasks.py index 5ac78cb..73f61bf 100755 --- a/tasks.py +++ b/tasks.py @@ -87,29 +87,29 @@ class ProblemLevel0(Problem): class ProblemLevel1(Problem): - def __init__(self, nb_operators=100, len_prompt=5, len_result=8): - self.len_prompt = len_prompt + def __init__(self, nb_operators=100, len_source=5, len_result=8): + self.len_source = len_source self.len_result = len_result self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1 self.operators = F.one_hot( - torch.rand(nb_operators, len_result, len_prompt).argmax(-1), - num_classes=len_prompt, + torch.rand(nb_operators, len_result, len_source).argmax(-1), + num_classes=len_source, ) def generate_sequences(self, nb): - a = self.len_nb_operator - b = a + 1 + self.len_prompt - sequences = torch.empty(nb, b + 1 + self.len_result, dtype=torch.int64) nb_operators = torch.randint(self.operators.size(0), (nb,)) - sequences[:, :a] = (nb_operators[:, None] / 10 ** torch.arange(a-1,-1,-1)) % 10 - sequences[:, a] = 10 - sequences[:, a + 1 : b] = torch.randint(10, (nb, b - a - 1)) - sequences[:, b] = 11 - - o = self.operators[nb_operators] - p = sequences[:, a + 1 : b] - print(f"{o.size()=} {p.size()=} {sequences[:,b+1:].size()=}") - sequences[:, b + 1 :] = o.bmm(p[:, :, None]).squeeze(-1) + operators = self.operators[nb_operators] + nb_operators = ( + nb_operators[:, None] + // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1) + ) % 10 + marker1 = torch.full((nb, 1), 10) + source = torch.randint(10, (nb, self.len_source)) + marker2 = torch.full((nb, 1), 11) + result = operators.bmm(source[:, :, None]).squeeze(-1) + print(f"{nb_operators.dtype=} {marker1.dtype=}") + sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1) + print(f"{sequences.size()=}") ar_mask = (sequences == 11).long() ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) return sequences, ar_mask @@ -118,6 +118,35 @@ class ProblemLevel1(Problem): return "".join("0123456789|>"[x.item()] for x in seq) +class ProblemLevel2(Problem): + def __init__(self, len_source=5, len_result=8): + self.len_source = len_source + self.len_result = len_result + + def generate_sequences(self, nb): + operators = F.one_hot( + torch.rand(nb, self.len_result, self.len_source).argmax(-1), + num_classes=self.len_source, + ) + source1 = torch.randint(10, (nb, self.len_source)) + marker1 = torch.full((nb, 1), 10) + result1 = operators.bmm(source1[:, :, None]).squeeze(-1) + marker2 = torch.full((nb, 1), 11) + source2 = torch.randint(10, (nb, self.len_source)) + marker3 = torch.full((nb, 1), 12) + result2 = operators.bmm(source2[:, :, None]).squeeze(-1) + + sequences = torch.cat( + (source1, marker1, result1, marker2, source2, marker3, result2), 1 + ) + ar_mask = (sequences == 12).long() + ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) + return sequences, ar_mask + + def seq2str(self, seq): + return "".join("0123456789>|~"[x.item()] for x in seq) + + ####################
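
For readers skimming the hunks: ProblemLevel1.generate_sequences now builds each training sequence by concatenating pre-computed blocks (the operator-id digits, a 10 marker, the source digits, an 11 marker, the result), and the new ProblemLevel2 shows the model one application of a random one-hot selection operator (source1 -> result1) and asks it to produce the same operator's output on a second source. What follows is a minimal standalone sketch of the ProblemLevel2 construction, written so it can be run outside tasks.py; it assumes only that PyTorch is installed, and the float()/long() cast around bmm is a precaution added here (integer bmm support varies across builds), not something present in the commit.

    # Standalone sketch of the ProblemLevel2 sequence layout and ar_mask
    # (mirrors the diff; the dtype casts around bmm are an added precaution).
    import torch
    import torch.nn.functional as F

    nb, len_source, len_result = 4, 5, 8

    # Each of the len_result output digits copies one randomly chosen input
    # digit: a per-sample one-hot (len_result x len_source) selection matrix.
    operators = F.one_hot(
        torch.rand(nb, len_result, len_source).argmax(-1),
        num_classes=len_source,
    )

    source1 = torch.randint(10, (nb, len_source))
    marker1 = torch.full((nb, 1), 10)   # rendered as ">" by ProblemLevel2.seq2str
    result1 = operators.float().bmm(source1[:, :, None].float()).long().squeeze(-1)
    marker2 = torch.full((nb, 1), 11)   # rendered as "|"
    source2 = torch.randint(10, (nb, len_source))
    marker3 = torch.full((nb, 1), 12)   # rendered as "~"
    result2 = operators.float().bmm(source2[:, :, None].float()).long().squeeze(-1)

    sequences = torch.cat(
        (source1, marker1, result1, marker2, source2, marker3, result2), 1
    )

    # The mask is 1 strictly after the "12" marker, so only result2 is generated
    # auto-regressively; the demonstration pair (source1 -> result1) and source2
    # are given as context.
    ar_mask = (sequences == 12).long()
    ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)

    print("".join("0123456789>|~"[x.item()] for x in sequences[0]))
    print("".join(str(x.item()) for x in ar_mask[0]))

Decoding sequences[0] with the same "0123456789>|~" table used by ProblemLevel2.seq2str shows the layout source1 > result1 | source2 ~ result2, with the mask covering only the final result2 block.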