Update.
[culture.git] / main.py
diff --git a/main.py b/main.py
index 97c7130..e058822 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -73,6 +73,10 @@ parser.add_argument("--dropout", type=float, default=0.1)
 
 parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
 
+parser.add_argument("--nb_gpts", type=int, default=5)
+
+parser.add_argument("--check", action="store_true", default=False)
+
 ######################################################################
 
 args = parser.parse_args()
@@ -183,6 +187,9 @@ for n in vars(args):
 
 ######################################################################
 
+if args.check:
+    args.nb_train_samples = 500
+    args.nb_test_samples = 100
 
 if args.physical_batch_size is None:
     args.physical_batch_size = args.batch_size
@@ -573,7 +580,7 @@ def create_quizzes(
     task.save_image(
         new_quizzes[:96],
         args.result_dir,
-        f"world_new_{n_epoch:04d}_{model.id:02d}.png",
+        f"world_quiz_{n_epoch:04d}_{model.id:02d}.png",
         log_string,
     )
 
@@ -582,7 +589,7 @@ def create_quizzes(
 
 models = []
 
-for k in range(5):
+for k in range(args.nb_gpts):
     model = mygpt.MyGPT(
         vocabulary_size=vocabulary_size,
         dim_model=args.dim_model,
@@ -606,6 +613,13 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 ######################################################################
 
 accuracy_to_make_quizzes = 0.975
+nb_new_quizzes_for_train = 1000
+nb_new_quizzes_for_test = 100
+
+if args.check:
+    accuracy_to_make_quizzes = 0.0
+    nb_new_quizzes_for_train = 10
+    nb_new_quizzes_for_test = 10
 
 for n_epoch in range(args.nb_epochs):
     # select the model with lowest accuracy
@@ -634,8 +648,8 @@ for n_epoch in range(args.nb_epochs):
             model,
             other_models,
             task,
-            nb_for_train=1000,
-            nb_for_test=100,
+            nb_for_train=nb_new_quizzes_for_train,
+            nb_for_test=nb_new_quizzes_for_test,
         )