Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 22 Jun 2024 09:25:06 +0000 (11:25 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 22 Jun 2024 09:25:06 +0000 (11:25 +0200)
main.py
tasks.py

diff --git a/main.py b/main.py
index b57c512..e058822 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -73,6 +73,8 @@ parser.add_argument("--dropout", type=float, default=0.1)
 
 parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
 
+parser.add_argument("--nb_gpts", type=int, default=5)
+
 parser.add_argument("--check", action="store_true", default=False)
 
 ######################################################################
@@ -185,9 +187,9 @@ for n in vars(args):
 
 ######################################################################
 
-if args.test:
-    args.nb_train_samples = 1000
-    args.nb_test_samples = 25
+if args.check:
+    args.nb_train_samples = 500
+    args.nb_test_samples = 100
 
 if args.physical_batch_size is None:
     args.physical_batch_size = args.batch_size
@@ -578,7 +580,7 @@ def create_quizzes(
     task.save_image(
         new_quizzes[:96],
         args.result_dir,
-        f"world_new_{n_epoch:04d}_{model.id:02d}.png",
+        f"world_quiz_{n_epoch:04d}_{model.id:02d}.png",
         log_string,
     )
 
@@ -587,7 +589,7 @@ def create_quizzes(
 
 models = []
 
-for k in range(5):
+for k in range(args.nb_gpts):
     model = mygpt.MyGPT(
         vocabulary_size=vocabulary_size,
         dim_model=args.dim_model,
@@ -614,7 +616,7 @@ accuracy_to_make_quizzes = 0.975
 nb_new_quizzes_for_train = 1000
 nb_new_quizzes_for_test = 100
 
-if args.test:
+if args.check:
     accuracy_to_make_quizzes = 0.0
     nb_new_quizzes_for_train = 10
     nb_new_quizzes_for_test = 10
index 50d541b..8680ba1 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -220,7 +220,7 @@ class World(Task):
         self.save_image(
             result[:96],
             result_dir,
-            f"world_result_{n_epoch:04d}_{model.id:02d}.png",
+            f"world_prediction_{n_epoch:04d}_{model.id:02d}.png",
             logger,
         )
 
@@ -294,13 +294,8 @@ class World(Task):
                 device=self.device,
             )
 
-            nb_correct += (
-                (
-                    (new_quizzes == result).long()
-                    * (inverted_quizzes, inverted_result).long()
-                )
-                .min(dim=-1)
-                .values
-            )
+            nb_correct += (new_quizzes == result).long().min(dim=-1).values * (
+                inverted_quizzes == inverted_result
+            ).long().min(dim=-1).values
 
         return new_quizzes, nb_correct