X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=quiz_machine.py;h=631d41bcb46939277a36cab2bdf7fb95aeac8d24;hb=f2ab5fd489adebe9b34ac825d39e41f13f287cb2;hp=4f704a0587ba5ffa282e1867f49c3535c7c152a1;hpb=7c79c0b140c88a529962945ec5b482fe90c55581;p=culture.git diff --git a/quiz_machine.py b/quiz_machine.py index 4f704a0..631d41b 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -416,7 +416,7 @@ class QuizMachine: def logproba_of_solutions(self, models, c_quizzes): logproba = c_quizzes.new_zeros( - c_quizzes.size(0), len(models), device=self.device + c_quizzes.size(0), len(models), device=self.device, dtype=torch.float32 ) for model in models: