+ nb_seq, len_prompt, len_result = 100, 5, 5
+ self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result))
+ self.seq[:, len_prompt] = 10
+
+ def generate_sequences(self, nb):
+ sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
+ ar_mask = (sequences==10).long()
+ ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
+ return sequences, ar_mask
+
+ # problems = [ProblemByheart()]
+ # nb_common_codes = 100
+
+ # def generate_sequences(nb_samples):
+ # problem_indexes = torch.randint(len(problems), (nb_samples,))
+ # nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0)
+ # print(f"{nb_samples_per_problem}")
+ # all_seq = []
+ # for nb, p in zip(nb_samples_per_problem, problems):
+ # all_seq.append(p.generate_sequences(nb_samples_per_problem[nb]))
+ # return all_seq
+
+ # for strain, stest in zip(train_seq, test_seq):
+ # s = torch.cat((strain, stest), 0)