From 5c751aa1bbfbcf42654f4626f81905acfa946c15 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 22 Jun 2024 11:25:06 +0200 Subject: [PATCH] Update. --- main.py | 14 ++++++++------ tasks.py | 13 ++++--------- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/main.py b/main.py index b57c512..e058822 100755 --- a/main.py +++ b/main.py @@ -73,6 +73,8 @@ parser.add_argument("--dropout", type=float, default=0.1) parser.add_argument("--deterministic_synthesis", action="store_true", default=False) +parser.add_argument("--nb_gpts", type=int, default=5) + parser.add_argument("--check", action="store_true", default=False) ###################################################################### @@ -185,9 +187,9 @@ for n in vars(args): ###################################################################### -if args.test: - args.nb_train_samples = 1000 - args.nb_test_samples = 25 +if args.check: + args.nb_train_samples = 500 + args.nb_test_samples = 100 if args.physical_batch_size is None: args.physical_batch_size = args.batch_size @@ -578,7 +580,7 @@ def create_quizzes( task.save_image( new_quizzes[:96], args.result_dir, - f"world_new_{n_epoch:04d}_{model.id:02d}.png", + f"world_quiz_{n_epoch:04d}_{model.id:02d}.png", log_string, ) @@ -587,7 +589,7 @@ def create_quizzes( models = [] -for k in range(5): +for k in range(args.nb_gpts): model = mygpt.MyGPT( vocabulary_size=vocabulary_size, dim_model=args.dim_model, @@ -614,7 +616,7 @@ accuracy_to_make_quizzes = 0.975 nb_new_quizzes_for_train = 1000 nb_new_quizzes_for_test = 100 -if args.test: +if args.check: accuracy_to_make_quizzes = 0.0 nb_new_quizzes_for_train = 10 nb_new_quizzes_for_test = 10 diff --git a/tasks.py b/tasks.py index 50d541b..8680ba1 100755 --- a/tasks.py +++ b/tasks.py @@ -220,7 +220,7 @@ class World(Task): self.save_image( result[:96], result_dir, - f"world_result_{n_epoch:04d}_{model.id:02d}.png", + f"world_prediction_{n_epoch:04d}_{model.id:02d}.png", logger, ) @@ -294,13 +294,8 @@ class World(Task): device=self.device, ) - nb_correct += ( - ( - (new_quizzes == result).long() - * (inverted_quizzes, inverted_result).long() - ) - .min(dim=-1) - .values - ) + nb_correct += (new_quizzes == result).long().min(dim=-1).values * ( + inverted_quizzes == inverted_result + ).long().min(dim=-1).values return new_quizzes, nb_correct -- 2.20.1