From fdb2071b7f6369df6e1a3ba6183e5e4db56ba8f7 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 14 Sep 2024 21:15:37 +0200 Subject: [PATCH] Update. --- attae.py | 2 -- main.py | 42 +++++++++++++++++------------------------- 2 files changed, 17 insertions(+), 27 deletions(-) diff --git a/attae.py b/attae.py index 9a2f240..06deed2 100755 --- a/attae.py +++ b/attae.py @@ -51,8 +51,6 @@ def attention(q, k, v): return y -attention = torch.compile(attention) - ###################################################################### diff --git a/main.py b/main.py index 8010fa4..62cbd2f 100755 --- a/main.py +++ b/main.py @@ -121,8 +121,6 @@ parser.add_argument("--nb_hints", type=int, default=25) parser.add_argument("--nb_runs", type=int, default=1) -parser.add_argument("--dirty_debug", action="store_true", default=False) - parser.add_argument("--test", type=str, default=None) parser.add_argument("--quizzes", type=str, default=None) @@ -210,7 +208,7 @@ else: if args.resume: if not os.path.isdir(args.result_dir): - print(f"Trying to resume with a non-existing result dir {args.result_dir}.") + print(f"Trying to resume from a non-existing result dir {args.result_dir}.") exit(1) else: try: @@ -276,10 +274,6 @@ else: assert len(gpus) == 0 main_device = torch.device("cpu") -if args.dirty_debug: - args.nb_train_samples = 2500 - args.nb_test_samples = 100 - if args.physical_batch_size is None: args.physical_batch_size = args.batch_size else: @@ -720,7 +714,7 @@ def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise x_t_with_mask = NTC_channel_cat(x_t, mask_generate) - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): logits_hat_x_0 = model(x_t_with_mask) return logits_hat_x_0 @@ -745,7 +739,7 @@ def ae_generate(model, x_0, mask_generate, nb_iterations_max=50, mask_hints=None for it in range(nb_iterations_max): x_t_with_mask = NTC_channel_cat(x_t, mask_generate) - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): logits = model(x_t_with_mask) logits[:, :, quiz_machine.problem.nb_colors :] = float("-inf") dist = torch.distributions.categorical.Categorical(logits=logits) @@ -894,7 +888,7 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi nb_train_samples, acc_train_loss = 0, 0.0 - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") for x_0, mask_generate in ae_batches( quiz_machine, @@ -910,7 +904,7 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi if nb_train_samples % args.batch_size == 0: model.optimizer.zero_grad() - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): logits = logits_hat_x_0_from_random_iteration( model, x_0, mask_generate, prompt_noise=args.prompt_noise ) @@ -963,6 +957,8 @@ for i in range(args.nb_models): dropout=args.dropout, ).to(main_device) + model = torch.compile(model) + model.id = i model.test_accuracy = 0.0 model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) @@ -1333,6 +1329,8 @@ def save_models(models, suffix=""): ###################################################################### for n_epoch in range(current_epoch, args.nb_epochs): + start_time = time.perf_counter() + state = { "current_epoch": n_epoch, "c_quizzes": c_quizzes, @@ -1349,12 +1347,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): # -------------------------------------------------------------------- - log_string(f"{time_train=} {time_c_quizzes=}") - - if ( - min([float(m.test_accuracy) for m in models]) > args.accuracy_to_make_c_quizzes - and time_train >= time_c_quizzes - ): + if min([float(m.test_accuracy) for m in models]) > args.accuracy_to_make_c_quizzes: if c_quizzes is None: save_models(models, "naive") @@ -1362,11 +1355,16 @@ for n_epoch in range(current_epoch, args.nb_epochs): nb_gpus = len(gpus) nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus - start_time = time.perf_counter() + args = [(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus] + + # Ugly hack: Only one thread during the first epoch so that + # compilation of the model does not explode + if n_epoch == 0: + args = args[:1] c_quizzes, agreements = multithread_execution( generate_ae_c_quizzes, - [(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus], + args, ) save_c_quizzes_with_scores( @@ -1385,8 +1383,6 @@ for n_epoch in range(current_epoch, args.nb_epochs): log_string(f"generated_c_quizzes {c_quizzes.size()=}") - time_train = 0 - for model in models: model.test_accuracy = 0 @@ -1400,8 +1396,6 @@ for n_epoch in range(current_epoch, args.nb_epochs): ranked_models = sorted(models, key=lambda m: float(m.test_accuracy)) weakest_models = ranked_models[: len(gpus)] - start_time = time.perf_counter() - # None if c_quizzes is None else c_quizzes[agreements[:, model.id]], multithread_execution( @@ -1412,8 +1406,6 @@ for n_epoch in range(current_epoch, args.nb_epochs): ], ) - time_train += int(time.perf_counter() - start_time) - # -------------------------------------------------------------------- save_models(models) -- 2.39.5