projects
/
picoclvr.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[picoclvr.git]
/
tasks.py
diff --git
a/tasks.py
b/tasks.py
index
73f61bf
..
c5418b4
100755
(executable)
--- a/
tasks.py
+++ b/
tasks.py
@@
-76,7
+76,7
@@
class Problem:
class ProblemLevel0(Problem):
def __init__(self, nb_sentences=100, len_prompt=5, len_result=5):
class ProblemLevel0(Problem):
def __init__(self, nb_sentences=100, len_prompt=5, len_result=5):
- self.seq = torch.randint(10, (nb_se
q
, len_prompt + 1 + len_result))
+ self.seq = torch.randint(10, (nb_se
ntences
, len_prompt + 1 + len_result))
self.seq[:, len_prompt] = 10
def generate_sequences(self, nb):
self.seq[:, len_prompt] = 10
def generate_sequences(self, nb):
@@
-104,7
+104,8
@@
class ProblemLevel1(Problem):
// 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
) % 10
marker1 = torch.full((nb, 1), 10)
// 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
) % 10
marker1 = torch.full((nb, 1), 10)
- source = torch.randint(10, (nb, self.len_source))
+ # source = torch.randint(10, (nb, self.len_source))
+ source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
marker2 = torch.full((nb, 1), 11)
result = operators.bmm(source[:, :, None]).squeeze(-1)
print(f"{nb_operators.dtype=} {marker1.dtype=}")
marker2 = torch.full((nb, 1), 11)
result = operators.bmm(source[:, :, None]).squeeze(-1)
print(f"{nb_operators.dtype=} {marker1.dtype=}")
@@
-128,7
+129,8
@@
class ProblemLevel2(Problem):
torch.rand(nb, self.len_result, self.len_source).argmax(-1),
num_classes=self.len_source,
)
torch.rand(nb, self.len_result, self.len_source).argmax(-1),
num_classes=self.len_source,
)
- source1 = torch.randint(10, (nb, self.len_source))
+ source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
+ # source1 = torch.randint(10, (nb, self.len_source))
marker1 = torch.full((nb, 1), 10)
result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
marker2 = torch.full((nb, 1), 11)
marker1 = torch.full((nb, 1), 10)
result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
marker2 = torch.full((nb, 1), 11)