X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=c5418b4bb616386d8f34e038ff96460dfafff585;hb=e781d77071fa26f393f50451f91c70f4a0850ca5;hp=73f61bf102dd46387bde2ecd99f0d7a74c2d7250;hpb=a3211f96c7426a613b82a2de87d4dd70640e8f46;p=picoclvr.git diff --git a/tasks.py b/tasks.py index 73f61bf..c5418b4 100755 --- a/tasks.py +++ b/tasks.py @@ -76,7 +76,7 @@ class Problem: class ProblemLevel0(Problem): def __init__(self, nb_sentences=100, len_prompt=5, len_result=5): - self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result)) + self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result)) self.seq[:, len_prompt] = 10 def generate_sequences(self, nb): @@ -104,7 +104,8 @@ class ProblemLevel1(Problem): // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1) ) % 10 marker1 = torch.full((nb, 1), 10) - source = torch.randint(10, (nb, self.len_source)) + # source = torch.randint(10, (nb, self.len_source)) + source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source] marker2 = torch.full((nb, 1), 11) result = operators.bmm(source[:, :, None]).squeeze(-1) print(f"{nb_operators.dtype=} {marker1.dtype=}") @@ -128,7 +129,8 @@ class ProblemLevel2(Problem): torch.rand(nb, self.len_result, self.len_source).argmax(-1), num_classes=self.len_source, ) - source1 = torch.randint(10, (nb, self.len_source)) + source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source] + # source1 = torch.randint(10, (nb, self.len_source)) marker1 = torch.full((nb, 1), 10) result1 = operators.bmm(source1[:, :, None]).squeeze(-1) marker2 = torch.full((nb, 1), 11)