+ sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
+ ar_mask = (sequences==10).long()
+ ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
+ return sequences, ar_mask
+
+ # problems = [ProblemByheart()]
+ # nb_common_codes = 100
+
+ # def generate_sequences(nb_samples):
+ # problem_indexes = torch.randint(len(problems), (nb_samples,))
+ # nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0)
+ # print(f"{nb_samples_per_problem}")
+ # all_seq = []
+ # for nb, p in zip(nb_samples_per_problem, problems):
+ # all_seq.append(p.generate_sequences(nb_samples_per_problem[nb]))
+ # return all_seq
+
+ # for strain, stest in zip(train_seq, test_seq):
+ # s = torch.cat((strain, stest), 0)