X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=main.py;h=3b29d01fa1dc444604fb15f7552f21ac887f924b;hb=87e2957c7f8262524e1dc627ad33725f387c7286;hp=45fa68c9eeaf0a0c802d0825ad82c4c1833614e6;hpb=9df1bd18f930f6b4a30b94fed6de684d5ceae3b7;p=culture.git diff --git a/main.py b/main.py index 45fa68c..3b29d01 100755 --- a/main.py +++ b/main.py @@ -183,7 +183,7 @@ for n in vars(args): ###################################################################### if args.check: - args.nb_train_samples = 500 + args.nb_train_samples = 2500 args.nb_test_samples = 100 if args.physical_batch_size is None: @@ -404,7 +404,7 @@ if args.check: nb_new_quizzes_for_test = 10 for n_epoch in range(args.nb_epochs): - a = [(model.id, model.main_test_accuracy) for model in models] + a = [(model.id, float(model.main_test_accuracy)) for model in models] a.sort(key=lambda p: p[0]) log_string(f"current accuracies {a}")