Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 10 Aug 2024 19:41:33 +0000 (21:41 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 10 Aug 2024 19:41:33 +0000 (21:41 +0200)
main.py

diff --git a/main.py b/main.py
index c4dcfb2..0670262 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -93,19 +93,19 @@ parser.add_argument("--gpus", type=str, default="all")
 
 parser.add_argument("--nb_gpts", type=int, default=5)
 
-parser.add_argument("--max_fail_to_validate", type=int, default=2)
+parser.add_argument("--max_fail_to_validate", type=int, default=3)
 
-parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.98)
+parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95)
 
 parser.add_argument("--proba_understands", type=float, default=0.95)
 
-parser.add_argument("--proba_not_understands", type=float, default=0.5)
+parser.add_argument("--proba_not_understands", type=float, default=0.1)
 
 parser.add_argument("--temperature_hot", type=float, default=1.5)
 
 parser.add_argument("--temperature_cold", type=float, default=1)
 
-parser.add_argument("--prompt_noise", type=float, default=0.0)
+parser.add_argument("--prompt_noise", type=float, default=0.05)
 
 parser.add_argument("--nb_averaging_rounds", type=int, default=3)
 
@@ -122,15 +122,15 @@ grids_tasks = ", ".join(
 parser.add_argument(
     "--grids_world_tasks",
     type=str,
-    default=None,
-    help="A comma-separated subset of: " + grids_tasks + ", or None for all.",
+    default="replace_color,translate,grow,frame",
+    help="A comma-separated subset of: " + grids_tasks + ".",
 )
 
 parser.add_argument(
     "--grids_science_tasks",
     type=str,
     default=None,
-    help="A comma-separated subset of: " + grids_tasks + ", or None for all.",
+    help="A comma-separated subset of: " + grids_tasks + ", or None.",
 )
 
 ######################################################################
@@ -166,8 +166,8 @@ default_args = {
     "model": "37M",
     "batch_size": 25,
     "inference_batch_size": 50,
-    "nb_train_samples": 100000,
-    "nb_test_samples": 10000,
+    "nb_train_samples": 40000,
+    "nb_test_samples": 1000,
 }
 
 for k, v in default_args.items():
@@ -492,11 +492,16 @@ def model_transformer_cold(model):
     # pass
 
 
+warnings.warn("*********** novel procedure!!! **********", RuntimeWarning)
+
 c_quizzes_procedure = [
+    # (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot),
+    # (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold),
+    # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
     (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot),
-    (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold),
-    (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
-    (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_transformer_cold),
+    (("f_B", "f_A", "A", "B"), (0, 1, 1, 0), model_transformer_cold),
+    (("A", "f_A", "B", "f_B"), (0, 0, 1, 1), model_transformer_cold),
+    (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_transformer_cold),
     # (("f_B", "f_A", "A", "B"), (0, 0, 1, 1), model_transformer_cold),
     # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
 ]
@@ -1078,6 +1083,10 @@ if args.test == "quant":
             ),
         )
 
+        print(model)
+        exit(0)
+
+
 ######################################################################
 
 current_epoch = 0