body = c_quizzes.repeat(n, 1)
if n < c_quiz_multiplier:
tail = c_quizzes[
- torch.randperm(c_quizzes.size(0))[: nb_samples // 2 - body.size(0)]
+ torch.randperm(c_quizzes.size(0), device=c_quizzes.device)[
+ : nb_samples // 2 - body.size(0)
+ ]
]
c_quizzes = torch.cat([body, tail], dim=0)
else:
c_quizzes = body
if c_quizzes.size(0) > nb_samples // 2:
- i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2]
+ i = torch.randperm(c_quizzes.size(0), device=c_quizzes.device)[
+ : nb_samples // 2
+ ]
c_quizzes = c_quizzes[i]
w_quizzes = problem.generate_w_quizzes(nb_samples - c_quizzes.size(0))