From 4084de83f541242ac815eba2c0883b84fd19141e Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 19 Sep 2024 16:14:12 +0200 Subject: [PATCH] Update. --- main.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index 7bdd09e..c08b04d 100755 --- a/main.py +++ b/main.py @@ -17,7 +17,7 @@ import threading, subprocess # import torch.multiprocessing as mp -# torch.set_float32_matmul_precision("high") +torch.set_float32_matmul_precision("high") # torch.set_default_dtype(torch.bfloat16) @@ -673,7 +673,9 @@ def evaluate_quizzes(quizzes, models, local_device): with_perturbations=True, local_device=local_device, ) + nb_mistakes = (result != quizzes).long().sum(dim=1) nb_correct += (nb_mistakes == 0).long() + result = predict_full( model=model, input=quizzes, @@ -851,7 +853,8 @@ def multithread_execution(fun, arguments): for args in arguments: # To get a different sequence between threads - log_string(f"dummy_rand {torch.rand(1)}") + # log_string(f"dummy_rand {torch.rand(1)}") + torch.rand(1) t = threading.Thread(target=threadable_fun, daemon=True, args=args) threads.append(t) t.start() -- 2.39.5