Update.
author	François Fleuret <francois@fleuret.org>
Fri, 20 Sep 2024 19:33:20 +0000 (21:33 +0200)
committer	François Fleuret <francois@fleuret.org>
Fri, 20 Sep 2024 19:33:20 +0000 (21:33 +0200)
main.py
problem.py

diff --git a/main.py b/main.py
index 6c20d2f..961ae81 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -354,20 +354,18 @@ def samples_for_prediction_imt(input):
     return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
 
 
-def ae_predict(model, imt_set, local_device=main_device, desc="predict"):
+def ae_predict(model, imt_set, local_device=main_device):
     model.eval().to(local_device)
 
     record = []
 
-    src = imt_set.split(args.eval_batch_size)
-
-    if desc is not None:
-        src = tqdm.tqdm(
-            src,
-            dynamic_ncols=True,
-            desc=desc,
-            total=imt_set.size(0) // args.eval_batch_size,
-        )
+    src = tqdm.tqdm(
+        imt_set.split(args.eval_batch_size),
+        dynamic_ncols=True,
+        desc="predict",
+        total=imt_set.size(0) // args.eval_batch_size,
+        delay=10,
+    )
 
     for imt in src:
         # some paranoia
@@ -383,7 +381,7 @@ def ae_predict(model, imt_set, local_device=main_device, desc="predict"):
     return torch.cat(record)
 
 
-def predict_full(
+def predict_the_four_grids(
     model, input, with_noise=False, with_hints=False, local_device=main_device
 ):
     input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1))
@@ -528,6 +526,7 @@ def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device):
         dynamic_ncols=True,
         desc=label,
         total=quizzes.size(0) // batch_size,
+        delay=10,
     ):
         input, masks, targets = imt.unbind(dim=1)
         if train and nb_samples % args.batch_size == 0:
@@ -633,25 +632,15 @@ def evaluate_quizzes(quizzes, models, local_device):
 
     for model in models:
         model = copy.deepcopy(model).to(local_device).eval()
-        result = predict_full(
+        predicted = predict_the_four_grids(
             model=model,
             input=quizzes,
             with_noise=False,
             with_hints=True,
             local_device=local_device,
         )
-
-        nb_mistakes = max_nb_mistakes_on_one_grid(quizzes, result)
+        nb_mistakes = max_nb_mistakes_on_one_grid(quizzes, predicted)
         nb_correct += (nb_mistakes == 0).long()
-
-        # result = predict_full(
-        # model=model,
-        # input=quizzes,
-        # with_noise=False,
-        # with_hints=False,
-        # local_device=local_device,
-        # )
-
         nb_wrong += (nb_mistakes >= args.nb_mistakes_to_be_wrong).long()
 
     to_keep = (nb_correct >= args.nb_have_to_be_correct) & (
@@ -910,7 +899,7 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 
 ######################################################################
 
-c_quizzes = None
+main_c_quizzes = None
 
 ######################################################################
 
@@ -919,7 +908,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
     state = {
         "current_epoch": n_epoch,
-        "c_quizzes": c_quizzes,
+        "main_c_quizzes": main_c_quizzes,
     }
 
     filename = "state.pth"
@@ -936,7 +925,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     lowest_test_accuracy = min([float(m.test_accuracy) for m in models])
 
     if lowest_test_accuracy >= args.accuracy_to_make_c_quizzes:
-        if c_quizzes is None:
+        if main_c_quizzes is None:
             save_models(models, "naive")
 
         nb_gpus = len(gpus)
@@ -953,20 +942,20 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
         log_string(f"generated_c_quizzes {new_c_quizzes.size()}")
 
-        c_quizzes = (
+        main_c_quizzes = (
             new_c_quizzes
-            if c_quizzes is None
-            else torch.cat([c_quizzes, new_c_quizzes])
+            if main_c_quizzes is None
+            else torch.cat([main_c_quizzes, new_c_quizzes])
         )
-        c_quizzes = c_quizzes[-args.nb_train_samples :]
+        main_c_quizzes = main_c_quizzes[-args.nb_train_samples :]
 
         for model in models:
             model.test_accuracy = 0
 
-    if c_quizzes is None:
+    if main_c_quizzes is None:
         log_string("no_c_quiz")
     else:
-        log_string(f"nb_c_quizzes {c_quizzes.size(0)}")
+        log_string(f"nb_c_quizzes {main_c_quizzes.size(0)}")
 
     # --------------------------------------------------------------------
 
@@ -979,7 +968,10 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
     multithread_execution(
         one_complete_epoch,
-        [(model, n_epoch, c_quizzes, gpu) for model, gpu in zip(weakest_models, gpus)],
+        [
+            (model, n_epoch, main_c_quizzes, gpu)
+            for model, gpu in zip(weakest_models, gpus)
+        ],
     )
 
     save_models(models)
diff --git a/problem.py b/problem.py
index 9bee5b2..8c1db63 100755 (executable)
--- a/problem.py
+++ b/problem.py
@@ -45,9 +45,7 @@ class Problem:
 
         if progress_bar:
             with tqdm.tqdm(
-                total=nb,
-                dynamic_ncols=True,
-                desc="world generation",
+                total=nb, dynamic_ncols=True, desc="world generation", delay=10
             ) as pbar:
                 while n < nb:
                     q = self.queue.get(block=True)