From a7b7b6533c0aed55c861543069cf534a92df2f38 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 20 Sep 2024 21:33:20 +0200 Subject: [PATCH] Update. --- main.py | 58 +++++++++++++++++++++++------------------------------- problem.py | 4 +--- 2 files changed, 26 insertions(+), 36 deletions(-) diff --git a/main.py b/main.py index 6c20d2f..961ae81 100755 --- a/main.py +++ b/main.py @@ -354,20 +354,18 @@ def samples_for_prediction_imt(input): return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1) -def ae_predict(model, imt_set, local_device=main_device, desc="predict"): +def ae_predict(model, imt_set, local_device=main_device): model.eval().to(local_device) record = [] - src = imt_set.split(args.eval_batch_size) - - if desc is not None: - src = tqdm.tqdm( - src, - dynamic_ncols=True, - desc=desc, - total=imt_set.size(0) // args.eval_batch_size, - ) + src = tqdm.tqdm( + imt_set.split(args.eval_batch_size), + dynamic_ncols=True, + desc="predict", + total=imt_set.size(0) // args.eval_batch_size, + delay=10, + ) for imt in src: # some paranoia @@ -383,7 +381,7 @@ def ae_predict(model, imt_set, local_device=main_device, desc="predict"): return torch.cat(record) -def predict_full( +def predict_the_four_grids( model, input, with_noise=False, with_hints=False, local_device=main_device ): input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1)) @@ -528,6 +526,7 @@ def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device): dynamic_ncols=True, desc=label, total=quizzes.size(0) // batch_size, + delay=10, ): input, masks, targets = imt.unbind(dim=1) if train and nb_samples % args.batch_size == 0: @@ -633,25 +632,15 @@ def evaluate_quizzes(quizzes, models, local_device): for model in models: model = copy.deepcopy(model).to(local_device).eval() - result = predict_full( + predicted = predict_the_four_grids( model=model, input=quizzes, with_noise=False, with_hints=True, local_device=local_device, ) - - nb_mistakes = max_nb_mistakes_on_one_grid(quizzes, result) + nb_mistakes = max_nb_mistakes_on_one_grid(quizzes, predicted) nb_correct += (nb_mistakes == 0).long() - - # result = predict_full( - # model=model, - # input=quizzes, - # with_noise=False, - # with_hints=False, - # local_device=local_device, - # ) - nb_wrong += (nb_mistakes >= args.nb_mistakes_to_be_wrong).long() to_keep = (nb_correct >= args.nb_have_to_be_correct) & ( @@ -910,7 +899,7 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)") ###################################################################### -c_quizzes = None +main_c_quizzes = None ###################################################################### @@ -919,7 +908,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): state = { "current_epoch": n_epoch, - "c_quizzes": c_quizzes, + "main_c_quizzes": main_c_quizzes, } filename = "state.pth" @@ -936,7 +925,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): lowest_test_accuracy = min([float(m.test_accuracy) for m in models]) if lowest_test_accuracy >= args.accuracy_to_make_c_quizzes: - if c_quizzes is None: + if main_c_quizzes is None: save_models(models, "naive") nb_gpus = len(gpus) @@ -953,20 +942,20 @@ for n_epoch in range(current_epoch, args.nb_epochs): log_string(f"generated_c_quizzes {new_c_quizzes.size()}") - c_quizzes = ( + main_c_quizzes = ( new_c_quizzes - if c_quizzes is None - else torch.cat([c_quizzes, new_c_quizzes]) + if main_c_quizzes is None + else torch.cat([main_c_quizzes, new_c_quizzes]) ) - c_quizzes = c_quizzes[-args.nb_train_samples :] + main_c_quizzes = main_c_quizzes[-args.nb_train_samples :] for model in models: model.test_accuracy = 0 - if c_quizzes is None: + if main_c_quizzes is None: log_string("no_c_quiz") else: - log_string(f"nb_c_quizzes {c_quizzes.size(0)}") + log_string(f"nb_c_quizzes {main_c_quizzes.size(0)}") # -------------------------------------------------------------------- @@ -979,7 +968,10 @@ for n_epoch in range(current_epoch, args.nb_epochs): multithread_execution( one_complete_epoch, - [(model, n_epoch, c_quizzes, gpu) for model, gpu in zip(weakest_models, gpus)], + [ + (model, n_epoch, main_c_quizzes, gpu) + for model, gpu in zip(weakest_models, gpus) + ], ) save_models(models) diff --git a/problem.py b/problem.py index 9bee5b2..8c1db63 100755 --- a/problem.py +++ b/problem.py @@ -45,9 +45,7 @@ class Problem: if progress_bar: with tqdm.tqdm( - total=nb, - dynamic_ncols=True, - desc="world generation", + total=nb, dynamic_ncols=True, desc="world generation", delay=10 ) as pbar: while n < nb: q = self.queue.get(block=True) -- 2.39.5