Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 31 Aug 2024 20:00:33 +0000 (22:00 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 31 Aug 2024 20:00:33 +0000 (22:00 +0200)
main.py

diff --git a/main.py b/main.py
index ab625cc..879d9fd 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -1316,11 +1316,18 @@ def c_quiz_criterion_two_certains(probas):
     return ((probas >= 0.99).long().sum(dim=1) >= 2) & (probas.min(dim=1).values <= 0.5)
 
 
+def c_quiz_criterion_some(probas):
+    return ((probas >= 0.8).long().sum(dim=1) >= 1) & (
+        (probas <= 0.2).long().sum(dim=1) >= 1
+    )
+
+
 def generate_ae_c_quizzes(models, local_device=main_device):
     criteria = [
         c_quiz_criterion_one_good_one_bad,
         c_quiz_criterion_diff,
         c_quiz_criterion_two_certains,
+        c_quiz_criterion_some,
     ]
 
     for m in models:
@@ -1336,7 +1343,9 @@ def generate_ae_c_quizzes(models, local_device=main_device):
         quizzes=template, quad_order=quad_order, quad_mask=(1, 1, 1, 1)
     )
 
-    duration_max = 600  # 3 * 3600
+    duration_max = 3600
+
+    wanted_nb = 512
 
     with torch.autograd.no_grad():
         records = [[] for _ in criteria]
@@ -1345,7 +1354,7 @@ def generate_ae_c_quizzes(models, local_device=main_device):
 
         while (
             time.perf_counter() < start_time + duration_max
-            and min([bag_len(bag) for bag in records]) < 128
+            and min([bag_len(bag) for bag in records]) < wanted_nb
         ):
             bl = [bag_len(bag) for bag in records]
             log_string(f"bag_len {bl}")
@@ -1353,39 +1362,53 @@ def generate_ae_c_quizzes(models, local_device=main_device):
             model = models[torch.randint(len(models), (1,)).item()]
             result = ae_generate(model, template, mask_generate, noise_proba)
 
-            probas = torch.cat(
-                [model_ae_proba_solutions(model, result)[:, None] for model in models],
-                dim=1,
-            )
+            to_keep = quiz_machine.problem.trivial(result) == False
+            result = result[to_keep]
 
-            for c, r in zip(criteria, records):
-                q = result[c(probas)]
-                if q.size(0) > 0:
-                    r.append(q)
+            if result.size(0) > 0:
+                probas = torch.cat(
+                    [
+                        model_ae_proba_solutions(model, result)[:, None]
+                        for model in models
+                    ],
+                    dim=1,
+                )
 
-    for n, u in enumerate(records):
-        quizzes = torch.cat(u, dim=0)[:128]
-        filename = f"culture_{n_epoch:04d}_{n:02d}.png"
+                for c, r in zip(criteria, records):
+                    q = result[c(probas)]
+                    if q.size(0) > 0:
+                        r.append(q)
 
-        # result, predicted_parts, correct_parts = bag_to_tensors(record)
+        duration = time.perf_counter() - start_time
 
-        # l = [model_ae_proba_solutions(model, result) for model in models]
-        # probas = torch.cat([x[:, None] for x in l], dim=1)
-        # comments = []
+        log_string(
+            f"generate_c_quizz_generation_speed {int(3600 * wanted_nb / duration)}/h"
+        )
 
-        # for l in probas:
-        # comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
+        for n, u in enumerate(records):
+            quizzes = torch.cat(u, dim=0)[:wanted_nb]
+            filename = f"culture_c_{n_epoch:04d}_{n:02d}.png"
 
-        quiz_machine.problem.save_quizzes_as_image(
-            args.result_dir,
-            filename,
-            quizzes=result,
-            # predicted_parts=predicted_parts,
-            # correct_parts=correct_parts,
-            # comments=comments,
-        )
+            # result, predicted_parts, correct_parts = bag_to_tensors(record)
 
-        log_string(f"wrote {filename}")
+            l = [model_ae_proba_solutions(model, quizzes) for model in models]
+            probas = torch.cat([x[:, None] for x in l], dim=1)
+            comments = []
+
+            for l in probas:
+                comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
+
+            quiz_machine.problem.save_quizzes_as_image(
+                args.result_dir,
+                filename,
+                quizzes=quizzes,
+                # predicted_parts=predicted_parts,
+                # correct_parts=correct_parts,
+                comments=comments,
+                nrow=8,
+            )
+
+            log_string(f"wrote {filename}")
 
 
 ######################################################################