Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 21 Sep 2024 14:59:01 +0000 (16:59 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 21 Sep 2024 14:59:01 +0000 (16:59 +0200)
main.py

diff --git a/main.py b/main.py
index 0056c76..cd9ec20 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -282,14 +282,18 @@ def generate_quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1):
             body = c_quizzes.repeat(n, 1)
             if n < c_quiz_multiplier:
                 tail = c_quizzes[
-                    torch.randperm(c_quizzes.size(0))[: nb_samples // 2 - body.size(0)]
+                    torch.randperm(c_quizzes.size(0), device=c_quizzes.device)[
+                        : nb_samples // 2 - body.size(0)
+                    ]
                 ]
                 c_quizzes = torch.cat([body, tail], dim=0)
             else:
                 c_quizzes = body
 
         if c_quizzes.size(0) > nb_samples // 2:
-            i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2]
+            i = torch.randperm(c_quizzes.size(0), device=c_quizzes.device)[
+                : nb_samples // 2
+            ]
             c_quizzes = c_quizzes[i]
 
         w_quizzes = problem.generate_w_quizzes(nb_samples - c_quizzes.size(0))