- # problem_indexes = torch.randint(len(problems), (nb_samples,))
- # nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0)
- # print(f"{nb_samples_per_problem}")
- # all_seq = []
- # for nb, p in zip(nb_samples_per_problem, problems):
- # all_seq.append(p.generate_sequences(nb_samples_per_problem[nb]))
- # return all_seq
+ # problem_indexes = torch.randint(len(problems), (nb_samples,))
+ # nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0)
+ # print(f"{nb_samples_per_problem}")
+ # all_seq = []
+ # for nb, p in zip(nb_samples_per_problem, problems):
+ # all_seq.append(p.generate_sequences(nb_samples_per_problem[nb]))
+ # return all_seq