Update.
[culture.git] / main.py
diff --git a/main.py b/main.py
index b6f2783..3b29d01 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -183,7 +183,7 @@ for n in vars(args):
 ######################################################################
 
 if args.check:
 ######################################################################
 
 if args.check:
-    args.nb_train_samples = 500
+    args.nb_train_samples = 2500
     args.nb_test_samples = 100
 
 if args.physical_batch_size is None:
     args.nb_test_samples = 100
 
 if args.physical_batch_size is None:
@@ -404,7 +404,7 @@ if args.check:
     nb_new_quizzes_for_test = 10
 
 for n_epoch in range(args.nb_epochs):
     nb_new_quizzes_for_test = 10
 
 for n_epoch in range(args.nb_epochs):
-    a = [(model.id, model.main_test_accuracy.item()) for model in models]
+    a = [(model.id, float(model.main_test_accuracy)) for model in models]
     a.sort(key=lambda p: p[0])
     log_string(f"current accuracies {a}")
 
     a.sort(key=lambda p: p[0])
     log_string(f"current accuracies {a}")