X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=3b29d01fa1dc444604fb15f7552f21ac887f924b;hb=87e2957c7f8262524e1dc627ad33725f387c7286;hp=09ae82345ec0258063f8a107a633b41b0c7779c8;hpb=4ec52fe66419a6e1d2b231108ccbb45902395fcc;p=culture.git diff --git a/main.py b/main.py index 09ae823..3b29d01 100755 --- a/main.py +++ b/main.py @@ -183,7 +183,7 @@ for n in vars(args): ###################################################################### if args.check: - args.nb_train_samples = 500 + args.nb_train_samples = 2500 args.nb_test_samples = 100 if args.physical_batch_size is None: @@ -360,7 +360,7 @@ def create_quizzes( task.store_new_quizzes(new_quizzes[nb_for_train:], for_train=False) task.save_image( - new_quizzes[:96], + new_quizzes[:72], args.result_dir, f"world_quiz_{n_epoch:04d}_{model.id:02d}.png", log_string, @@ -404,7 +404,7 @@ if args.check: nb_new_quizzes_for_test = 10 for n_epoch in range(args.nb_epochs): - a = [(model.id, model.main_test_accuracy) for model in models] + a = [(model.id, float(model.main_test_accuracy)) for model in models] a.sort(key=lambda p: p[0]) log_string(f"current accuracies {a}")