Update.
author: François Fleuret <francois@fleuret.org>
Sun, 1 Sep 2024 16:29:42 +0000 (18:29 +0200)
committer: François Fleuret <francois@fleuret.org>
Sun, 1 Sep 2024 16:29:42 +0000 (18:29 +0200)
main.py

diff --git a/main.py b/main.py
index e533802..b87518e 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -51,7 +51,7 @@ parser.add_argument("--batch_size", type=int, default=25)
 
 parser.add_argument("--physical_batch_size", type=int, default=None)
 
-parser.add_argument("--inference_batch_size", type=int, default=25)
+parser.add_argument("--inference_batch_size", type=int, default=50)
 
 parser.add_argument("--nb_train_samples", type=int, default=40000)
 
@@ -61,7 +61,7 @@ parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None)
 
 parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None)
 
-parser.add_argument("--c_quiz_multiplier", type=int, default=1)
+parser.add_argument("--c_quiz_multiplier", type=int, default=4)
 
 parser.add_argument("--learning_rate", type=float, default=5e-4)
 
@@ -973,7 +973,10 @@ def ae_batches(
     c_quiz_bags = [] if c_quizzes is None else [c_quizzes.to("cpu")]
 
     full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
-        nb, c_quiz_bags, data_structures=data_structures
+        nb,
+        c_quiz_bags,
+        data_structures=data_structures,
+        c_quiz_multiplier=args.c_quiz_multiplier,
     )
 
     src = zip(
@@ -1052,14 +1055,14 @@ def ae_generate(model, input, mask_generate, nb_iterations_max=50):
         update = (1 - mask_to_change) * input + mask_to_change * final
 
         if update.equal(input):
-            log_string(f"exit after {it+1} iterations")
+            log_string(f"exit after {it+1} iterations")
             break
         else:
             changed = changed & (update != input).max(dim=1).values
             input[changed] = update[changed]
 
-    if it == nb_iterations_max:
-        log_string(f"remains {changed.long().sum()}")
+    # if it == nb_iterations_max:
+    #     log_string(f"remains {changed.long().sum()}")
 
     return input
 
@@ -1348,7 +1351,7 @@ def generate_ae_c_quizzes(models, local_device=main_device):
     quad_order = ("A", "f_A", "B", "f_B")
 
     template = quiz_machine.problem.create_empty_quizzes(
-        nb=args.batch_size, quad_order=quad_order
+        nb=args.inference_batch_size, quad_order=quad_order
     ).to(local_device)
 
     mask_generate = quiz_machine.make_quiz_mask(
@@ -1357,15 +1360,16 @@ def generate_ae_c_quizzes(models, local_device=main_device):
 
     duration_max = 4 * 3600
 
-    # wanted_nb = 240
-    # nb_to_save = 240
-
-    wanted_nb = args.nb_train_samples // 4
+    wanted_nb = 128
     nb_to_save = 128
 
+    # wanted_nb = args.nb_train_samples // args.c_quiz_multiplier
+    # nb_to_save = 256
+
     with torch.autograd.no_grad():
         records = [[] for _ in criteria]
 
+        last_log = -1
         start_time = time.perf_counter()
 
         while (
@@ -1374,8 +1378,6 @@ def generate_ae_c_quizzes(models, local_device=main_device):
         ):
             model = models[torch.randint(len(models), (1,)).item()]
             result = ae_generate(model, template, mask_generate)
-            bl = [bag_len(bag) for bag in records]
-            log_string(f"bag_len {bl} model {model.id}")
 
             to_keep = quiz_machine.problem.trivial(result) == False
             result = result[to_keep]
@@ -1394,13 +1396,34 @@ def generate_ae_c_quizzes(models, local_device=main_device):
                     if q.size(0) > 0:
                         r.append(q)
 
+            duration = time.perf_counter() - start_time
+            nb_generated = min([bag_len(bag) for bag in records])
+
+            if last_log < 0 or duration > last_log + 60:
+                last_log = duration
+                if nb_generated > 0:
+                    if nb_generated < wanted_nb:
+                        d = (wanted_nb - nb_generated) * duration / nb_generated
+                        e = (
+                            datetime.datetime.now() + datetime.timedelta(seconds=d)
+                        ).strftime("%a %H:%M")
+                    else:
+                        e = "now!"
+                else:
+                    e = "???"
+
+                bl = [bag_len(bag) for bag in records]
+                log_string(
+                    f"bag_len {bl} model {model.id} (finishes {e} -- {int((nb_generated * 3600)/duration)}/h)"
+                )
+
         duration = time.perf_counter() - start_time
 
         log_string(f"generate_c_quizz_speed {int(3600 * wanted_nb / duration)}/h")
 
         for n, u in enumerate(records):
             quizzes = torch.cat(u, dim=0)[:nb_to_save]
-            filename = f"culture_c_{n_epoch:04d}_{n:02d}.png"
+            filename = f"culture_c_quiz_{n_epoch:04d}_{n:02d}.png"
 
             # result, predicted_parts, correct_parts = bag_to_tensors(record)
 
@@ -1419,7 +1442,7 @@ def generate_ae_c_quizzes(models, local_device=main_device):
                 # correct_parts=correct_parts,
                 comments=comments,
                 delta=True,
-                nrow=12,
+                nrow=8,
             )
 
             log_string(f"wrote {filename}")
@@ -1475,6 +1498,9 @@ last_n_epoch_c_quizzes = 0
 
 c_quizzes = None
 
+time_c_quizzes = 0
+time_train = 0
+
 for n_epoch in range(current_epoch, args.nb_epochs):
     start_time = time.perf_counter()
 
@@ -1501,10 +1527,13 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     if (
         n_epoch >= 200
         and min([m.test_accuracy for m in models]) > args.accuracy_to_make_c_quizzes
-        and n_epoch >= last_n_epoch_c_quizzes + 10
+        and time_train >= time_c_quizzes
     ):
         last_n_epoch_c_quizzes = n_epoch
+        start_time = time.perf_counter()
         c_quizzes = generate_ae_c_quizzes(models, local_device=main_device)
+        time_c_quizzes = time.perf_counter() - start_time
+        time_train = 0
 
     if c_quizzes is None:
         log_string("no_c_quiz")
@@ -1534,6 +1563,8 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     for t in threads:
         t.join()
 
+    time_train += time.perf_counter() - start_time
+
     # --------------------------------------------------------------------
 
     for model in models: