# import torch.multiprocessing as mp
-# torch.set_float32_matmul_precision("high")
+torch.set_float32_matmul_precision("high")
# torch.set_default_dtype(torch.bfloat16)
with_perturbations=True,
local_device=local_device,
)
+ nb_mistakes = (result != quizzes).long().sum(dim=1)
nb_correct += (nb_mistakes == 0).long()
+
result = predict_full(
model=model,
input=quizzes,
for args in arguments:
# To get a different sequence between threads
- log_string(f"dummy_rand {torch.rand(1)}")
+ # log_string(f"dummy_rand {torch.rand(1)}")
+ torch.rand(1)
t = threading.Thread(target=threadable_fun, daemon=True, args=args)
threads.append(t)
t.start()