projects
/
picoclvr.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[picoclvr.git]
/
tasks.py
diff --git
a/tasks.py
b/tasks.py
index
e7c2f75
..
c5418b4
100755
(executable)
--- a/
tasks.py
+++ b/
tasks.py
@@
-104,7
+104,8
@@
class ProblemLevel1(Problem):
// 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
) % 10
marker1 = torch.full((nb, 1), 10)
// 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
) % 10
marker1 = torch.full((nb, 1), 10)
- source = torch.randint(10, (nb, self.len_source))
+ # source = torch.randint(10, (nb, self.len_source))
+ source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
marker2 = torch.full((nb, 1), 11)
result = operators.bmm(source[:, :, None]).squeeze(-1)
print(f"{nb_operators.dtype=} {marker1.dtype=}")
marker2 = torch.full((nb, 1), 11)
result = operators.bmm(source[:, :, None]).squeeze(-1)
print(f"{nb_operators.dtype=} {marker1.dtype=}")
@@
-128,7
+129,8
@@
class ProblemLevel2(Problem):
torch.rand(nb, self.len_result, self.len_source).argmax(-1),
num_classes=self.len_source,
)
torch.rand(nb, self.len_result, self.len_source).argmax(-1),
num_classes=self.len_source,
)
- source1 = torch.randint(10, (nb, self.len_source))
+ source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
+ # source1 = torch.randint(10, (nb, self.len_source))
marker1 = torch.full((nb, 1), 10)
result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
marker2 = torch.full((nb, 1), 11)
marker1 = torch.full((nb, 1), 10)
result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
marker2 = torch.full((nb, 1), 11)