X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=b6f278359f149446b416cbb4a7d6f632f5e13f96;hb=36a1440d01cc15643849f5ba421f89ac403ccd82;hp=09ae82345ec0258063f8a107a633b41b0c7779c8;hpb=4ec52fe66419a6e1d2b231108ccbb45902395fcc;p=culture.git diff --git a/main.py b/main.py index 09ae823..b6f2783 100755 --- a/main.py +++ b/main.py @@ -360,7 +360,7 @@ def create_quizzes( task.store_new_quizzes(new_quizzes[nb_for_train:], for_train=False) task.save_image( - new_quizzes[:96], + new_quizzes[:72], args.result_dir, f"world_quiz_{n_epoch:04d}_{model.id:02d}.png", log_string, @@ -404,7 +404,7 @@ if args.check: nb_new_quizzes_for_test = 10 for n_epoch in range(args.nb_epochs): - a = [(model.id, model.main_test_accuracy) for model in models] + a = [(model.id, model.main_test_accuracy.item()) for model in models] a.sort(key=lambda p: p[0]) log_string(f"current accuracies {a}")