parser.add_argument("--nb_models", type=int, default=5)
+parser.add_argument("--proba_plasticity", type=float, default=0.0)
+
parser.add_argument("--diffusion_nb_iterations", type=int, default=25)
parser.add_argument("--diffusion_proba_corruption", type=float, default=0.05)
######################################################################
-def new_model(i):
+def new_model(id=-1):
if args.model_type == "standard":
model_constructor = attae.AttentionAE
elif args.model_type == "functional":
dropout=args.dropout,
)
- model.id = i
+ model.id = id
model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
model.test_accuracy = 0.0
model.nb_epochs = 0
######################################################################
+
+def inject_plasticity(model, proba):
+    """Re-randomize a random fraction of model's weights in place.
+
+    Each scalar weight is, independently with probability `proba`,
+    replaced by the corresponding weight of a freshly initialized
+    model ("plasticity injection"). No-op when proba <= 0.
+
+    Args:
+        model: the model whose parameters are mutated in place.
+        proba: per-weight replacement probability in [0, 1].
+    """
+    if proba <= 0:
+        return
+
+    # Fresh random weights to draw replacements from.
+    dummy = new_model()
+
+    with torch.no_grad():
+        for p, q in zip(model.parameters(), dummy.parameters()):
+            # rand_like keeps the mask on p's device/dtype-compatible path
+            # (plain torch.rand would allocate on CPU and break on GPU).
+            mask = (torch.rand_like(p) <= proba).long()
+            p[...] = (1 - mask) * p + mask * q
+
+
+
+######################################################################
+
+chunk_size = 100
+
problem = grids.Grids(
- max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
- chunk_size=100,
+ max_nb_cached_chunks=len(gpus) * args.nb_train_samples // chunk_size,
+ chunk_size=chunk_size,
nb_threads=args.nb_threads,
tasks=args.grids_world_tasks,
)
for i in range(args.nb_models):
model = new_model(i)
- # model = torch.compile(model)
+ model = torch.compile(model)
models.append(model)
test_c_quizzes = train_c_quizzes[nb_correct >= args.nb_have_to_be_correct]
for model in models:
+ inject_plasticity(model, args.proba_plasticity)
model.test_accuracy = 0
if train_c_quizzes is None: