def __init__(self):
nb_seq, len_prompt, len_result = 100, 5, 5
self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result))
- self.seq[:,len_prompt]=-1
+ self.seq[:, len_prompt] = -1
def generate_sequences(self, nb):
return self.seq[torch.randint(self.seq.size(0), (nb,))]
+
class SandBox(Task):
def __init__(
self,
self.batch_size = batch_size
- problems = [ ProblemByheart() ]
+ problems = [ProblemByheart()]
nb_common_codes = 100
def generate_sequences(nb_samples):
nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0)
print(f"{nb_samples_per_problem}")
all_seq = []
- for nb, p in zip(nb_samples_per_problem,problems):
+ for nb, p in zip(nb_samples_per_problem, problems):
all_seq.append(p.generate_sequences(nb_samples_per_problem[nb]))
return all_seq
test_seq = generate_sequences(nb_test_samples)
for strain, stest in zip(train_seq, test_seq):
- s = torch.cat((strain,stest),0)
+ s = torch.cat((strain, stest), 0)
self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1