Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 25 Sep 2024 06:11:55 +0000 (08:11 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 25 Sep 2024 06:11:55 +0000 (08:11 +0200)
main.py

diff --git a/main.py b/main.py
index de04d5b..87e73ad 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -45,23 +45,6 @@ parser.add_argument("--train_batch_size", type=int, default=None)
 
 parser.add_argument("--eval_batch_size", type=int, default=25)
 
-parser.add_argument("--nb_train_samples", type=int, default=50000)
-
-
-parser.add_argument("--nb_test_samples", type=int, default=10000)
-
-parser.add_argument("--nb_c_quizzes", type=int, default=5000)
-
-parser.add_argument("--c_quiz_multiplier", type=int, default=1)
-
-parser.add_argument("--learning_rate", type=float, default=5e-4)
-
-parser.add_argument("--nb_have_to_be_correct", type=int, default=3)
-
-parser.add_argument("--nb_have_to_be_wrong", type=int, default=1)
-
-parser.add_argument("--nb_mistakes_to_be_wrong", type=int, default=5)
-
 # ----------------------------------
 
 parser.add_argument("--model_type", type=str, default="standard")
@@ -80,6 +63,8 @@ parser.add_argument("--nb_blocks", type=int, default=None)
 
 parser.add_argument("--dropout", type=float, default=0.5)
 
+parser.add_argument("--learning_rate", type=float, default=5e-4)
+
 # ----------------------------------
 
 parser.add_argument("--nb_threads", type=int, default=1)
@@ -88,21 +73,35 @@ parser.add_argument("--gpus", type=str, default="all")
 
 # ----------------------------------
 
-parser.add_argument("--nb_models", type=int, default=5)
-
-parser.add_argument("--proba_plasticity", type=float, default=0.25)
+parser.add_argument("--c_quiz_multiplier", type=int, default=1)
 
 parser.add_argument("--diffusion_nb_iterations", type=int, default=25)
 
 parser.add_argument("--diffusion_proba_corruption", type=float, default=0.05)
 
+parser.add_argument("--nb_train_samples", type=int, default=50000)
+
+parser.add_argument("--nb_test_samples", type=int, default=10000)
+
+parser.add_argument("--nb_c_quizzes", type=int, default=5000)
+
+# ----------------------------------
+
+parser.add_argument("--nb_models", type=int, default=5)
+
+parser.add_argument("--nb_have_to_be_correct", type=int, default=3)
+
+parser.add_argument("--nb_have_to_be_wrong", type=int, default=1)
+
+parser.add_argument("--nb_mistakes_to_be_wrong", type=int, default=5)
+
+parser.add_argument("--proba_plasticity", type=float, default=0.0)
+
 parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95)
 
 parser.add_argument("--proba_prompt_noise", type=float, default=0.05)
 
-parser.add_argument("--proba_hint", type=float, default=0.25)
-
-parser.add_argument("--quizzes", type=str, default=None)
+parser.add_argument("--proba_hints", type=float, default=0.1)
 
 ######################################################################
 
@@ -318,7 +317,7 @@ def generate_quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1):
 
 def add_hints_imt(imt_set):
     """Set every component of the mask to zero with probability
-    args.proba_hint, and for each component set to zero, copy the
+    args.proba_hints, and for each component set to zero, copy the
     corresponding value from the target into the input
 
     """
@@ -327,8 +326,9 @@ def add_hints_imt(imt_set):
     # t = h.sort(dim=1).values[:, args.nb_hints, None]
     # mask_hints = (h < t).long()
     mask_hints = (
-        torch.rand(input.size(), device=input.device) < args.proba_hint
+        torch.rand(input.size(), device=input.device) < args.proba_hints
     ).long() * masks
+
     masks = (1 - mask_hints) * masks
     input = (1 - mask_hints) * input + mask_hints * targets
     return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
@@ -610,8 +610,8 @@ def one_complete_epoch(
 
     quizzes = generate_quiz_set(
         args.nb_test_samples,
-        c_quizzes=None,
-        c_quizzes=test_c_quizzes,
+        c_quizzes=None,
+        c_quizzes=test_c_quizzes,
         c_quiz_multiplier=args.c_quiz_multiplier,
     )
     imt_set = samples_for_prediction_imt(quizzes.to(local_device))
@@ -966,9 +966,9 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 ###
 ### test_c_quizzes = train_c_quizzes[nb_correct >= len(models)//2]
 ###
-### for model in models:
-###     inject_plasticity(model, args.proba_plasticity)
-###     model.test_accuracy = 0
+# for model in models:
+# inject_plasticity(model, args.proba_plasticity)
+# model.test_accuracy = 0
 #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
 
 for n_epoch in range(current_epoch, args.nb_epochs):