From 6c6c62671ded6262be867ca1ac43c7fefd812b90 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 21 Sep 2024 16:59:01 +0200 Subject: [PATCH] Update. --- main.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index 0056c76..cd9ec20 100755 --- a/main.py +++ b/main.py @@ -282,14 +282,18 @@ def generate_quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1): body = c_quizzes.repeat(n, 1) if n < c_quiz_multiplier: tail = c_quizzes[ - torch.randperm(c_quizzes.size(0))[: nb_samples // 2 - body.size(0)] + torch.randperm(c_quizzes.size(0), device=c_quizzes.device)[ + : nb_samples // 2 - body.size(0) + ] ] c_quizzes = torch.cat([body, tail], dim=0) else: c_quizzes = body if c_quizzes.size(0) > nb_samples // 2: - i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2] + i = torch.randperm(c_quizzes.size(0), device=c_quizzes.device)[ + : nb_samples // 2 + ] c_quizzes = c_quizzes[i] w_quizzes = problem.generate_w_quizzes(nb_samples - c_quizzes.size(0)) -- 2.39.5