From: François Fleuret
Date: Tue, 24 Sep 2024 06:17:04 +0000 (+0200)
Subject: Update.
X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=ad1dcd79ef870aebb31e9a7740bf4ed02db4a253;p=culture.git

Update.
---

diff --git a/main.py b/main.py
index ed43b33..3424b27 100755
--- a/main.py
+++ b/main.py
@@ -90,6 +90,8 @@ parser.add_argument("--gpus", type=str, default="all")
 
 parser.add_argument("--nb_models", type=int, default=5)
 
+parser.add_argument("--proba_plasticity", type=float, default=0.0)
+
 parser.add_argument("--diffusion_nb_iterations", type=int, default=25)
 
 parser.add_argument("--diffusion_proba_corruption", type=float, default=0.05)
@@ -837,7 +839,7 @@ def save_quiz_image(models, c_quizzes, filename, local_device=main_device):
 ######################################################################
 
 
-def new_model(i):
+def new_model(id=-1):
     if args.model_type == "standard":
         model_constructor = attae.AttentionAE
     elif args.model_type == "functional":
@@ -855,7 +857,7 @@ def new_model(i):
         dropout=args.dropout,
     )
 
-    model.id = i
+    model.id = id
     model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
     model.test_accuracy = 0.0
     model.nb_epochs = 0
@@ -865,9 +867,26 @@ def new_model(i):
 
 ######################################################################
 
+
+def inject_plasticity(model, proba):
+    if proba <= 0:
+        return
+
+    dummy = new_model()
+
+    with torch.no_grad():
+        for p, q in zip(model.parameters(), dummy.parameters()):
+            mask = (torch.rand(p.size()) <= proba).long()
+            p[...] = (1 - mask) * p + mask * q
+
+
+######################################################################
+
+chunk_size = 100
+
 problem = grids.Grids(
-    max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
-    chunk_size=100,
+    max_nb_cached_chunks=len(gpus) * args.nb_train_samples // chunk_size,
+    chunk_size=chunk_size,
     nb_threads=args.nb_threads,
     tasks=args.grids_world_tasks,
 )
@@ -888,7 +907,7 @@ models = []
 
 for i in range(args.nb_models):
     model = new_model(i)
-    # model = torch.compile(model)
+    model = torch.compile(model)
 
     models.append(model)
 
@@ -998,6 +1017,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         test_c_quizzes = train_c_quizzes[nb_correct >= args.nb_have_to_be_correct]
 
     for model in models:
+        inject_plasticity(model, args.proba_plasticity)
        model.test_accuracy = 0
 
     if train_c_quizzes is None:
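
Editor's note: below is a minimal, standalone sketch of the parameter re-initialization idea behind the new inject_plasticity step: with probability proba, each parameter entry of a model is overwritten by the corresponding entry of a freshly initialized "dummy" model, re-injecting randomness into the weights between epochs. The nn.Linear stand-ins, the proba value, and the use of torch.rand_like / torch.where (which keeps the mask on the parameter's device, unlike the committed torch.rand(p.size()) integer mask) are illustrative assumptions, not part of the commit.

import torch
import torch.nn as nn


def inject_plasticity(model, dummy, proba):
    # Replace each parameter entry of `model` by the corresponding entry of
    # `dummy` (a freshly initialized model with identical shapes) with
    # probability `proba`; leave it untouched otherwise.
    if proba <= 0:
        return
    with torch.no_grad():
        for p, q in zip(model.parameters(), dummy.parameters()):
            mask = torch.rand_like(p) <= proba  # True where we re-initialize
            p[...] = torch.where(mask, q, p)


# Toy usage (illustrative): re-initialize roughly 10% of the weights.
model, dummy = nn.Linear(8, 8), nn.Linear(8, 8)
before = model.weight.clone()
inject_plasticity(model, dummy, proba=0.1)
print("fraction re-initialized:", (model.weight != before).float().mean().item())

In the patch itself, the dummy model is built with new_model() so it matches the architecture selected by the command-line arguments, and the per-epoch call is controlled by the new --proba_plasticity option (default 0.0, i.e. disabled).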