projects
/
picoclvr.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[picoclvr.git]
/
tasks.py
diff --git
a/tasks.py
b/tasks.py
index
8b57cb2
..
9cd06ae
100755
(executable)
--- a/
tasks.py
+++ b/
tasks.py
@@
-73,7
+73,12
@@
class Problem:
class ProblemByheart(Problem):
def __init__(self):
class ProblemByheart(Problem):
def __init__(self):
- pass
+ nb_seq, len_prompt, len_result = 100, 5, 5
+ self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result))
+ self.seq[:, len_prompt] = -1
+
+ def generate_sequences(self, nb):
+ return self.seq[torch.randint(self.seq.size(0), (nb,))]
class SandBox(Task):
class SandBox(Task):
@@
-89,13
+94,23
@@
class SandBox(Task):
self.batch_size = batch_size
self.batch_size = batch_size
+ problems = [ProblemByheart()]
+ nb_common_codes = 100
+
def generate_sequences(nb_samples):
problem_indexes = torch.randint(len(problems), (nb_samples,))
nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0)
print(f"{nb_samples_per_problem}")
def generate_sequences(nb_samples):
problem_indexes = torch.randint(len(problems), (nb_samples,))
nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0)
print(f"{nb_samples_per_problem}")
+ all_seq = []
+ for nb, p in zip(nb_samples_per_problem, problems):
+ all_seq.append(p.generate_sequences(nb_samples_per_problem[nb]))
+ return all_seq
+
+ train_seq = generate_sequences(nb_train_samples)
+ test_seq = generate_sequences(nb_test_samples)
- self.train_input = generate_sequences(nb_train_samples)
-
self.test_input = generate_sequences(nb_test_samples
)
+ for strain, stest in zip(train_seq, test_seq):
+
s = torch.cat((strain, stest), 0
)
self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1