Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 24 Sep 2024 06:17:04 +0000 (08:17 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 24 Sep 2024 06:17:04 +0000 (08:17 +0200)
main.py

diff --git a/main.py b/main.py
index ed43b33..3424b27 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -90,6 +90,8 @@ parser.add_argument("--gpus", type=str, default="all")
 
 parser.add_argument("--nb_models", type=int, default=5)
 
+parser.add_argument("--proba_plasticity", type=float, default=0.0)
+
 parser.add_argument("--diffusion_nb_iterations", type=int, default=25)
 
 parser.add_argument("--diffusion_proba_corruption", type=float, default=0.05)
@@ -837,7 +839,7 @@ def save_quiz_image(models, c_quizzes, filename, local_device=main_device):
 ######################################################################
 
 
-def new_model(i):
+def new_model(id=-1):
     if args.model_type == "standard":
         model_constructor = attae.AttentionAE
     elif args.model_type == "functional":
@@ -855,7 +857,7 @@ def new_model(i):
         dropout=args.dropout,
     )
 
-    model.id = i
+    model.id = id
     model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
     model.test_accuracy = 0.0
     model.nb_epochs = 0
@@ -865,9 +867,26 @@ def new_model(i):
 
 ######################################################################
 
+
+def inject_plasticity(model, proba):
+    if proba <= 0:
+        return
+
+    dummy = new_model()
+
+    with torch.no_grad():
+        for p, q in zip(mode.parameters(), dummy.parameters()):
+            mask = (torch.rand(p.size()) <= proba).long()
+            p[...] = (1 - mask) * p + mmask * q
+
+
+######################################################################
+
+chunk_size = 100
+
 problem = grids.Grids(
-    max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
-    chunk_size=100,
+    max_nb_cached_chunks=len(gpus) * args.nb_train_samples // chunk_size,
+    chunk_size=chunk_size,
     nb_threads=args.nb_threads,
     tasks=args.grids_world_tasks,
 )
@@ -888,7 +907,7 @@ models = []
 
 for i in range(args.nb_models):
     model = new_model(i)
-    model = torch.compile(model)
+    model = torch.compile(model)
 
     models.append(model)
 
@@ -998,6 +1017,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         test_c_quizzes = train_c_quizzes[nb_correct >= args.nb_have_to_be_correct]
 
         for model in models:
+            inject_plasticity(model, args.proba_plasticity)
             model.test_accuracy = 0
 
     if train_c_quizzes is None: