Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 19 Sep 2024 14:14:12 +0000 (16:14 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 19 Sep 2024 14:14:12 +0000 (16:14 +0200)
main.py

diff --git a/main.py b/main.py
index 7bdd09e..c08b04d 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -17,7 +17,7 @@ import threading, subprocess
 
 # import torch.multiprocessing as mp
 
-torch.set_float32_matmul_precision("high")
+torch.set_float32_matmul_precision("high")
 
 # torch.set_default_dtype(torch.bfloat16)
 
@@ -673,7 +673,9 @@ def evaluate_quizzes(quizzes, models, local_device):
             with_perturbations=True,
             local_device=local_device,
         )
+        nb_mistakes = (result != quizzes).long().sum(dim=1)
         nb_correct += (nb_mistakes == 0).long()
+
         result = predict_full(
             model=model,
             input=quizzes,
@@ -851,7 +853,8 @@ def multithread_execution(fun, arguments):
 
     for args in arguments:
         # To get a different sequence between threads
-        log_string(f"dummy_rand {torch.rand(1)}")
+        # log_string(f"dummy_rand {torch.rand(1)}")
+        torch.rand(1)
         t = threading.Thread(target=threadable_fun, daemon=True, args=args)
         threads.append(t)
         t.start()