Update: remove the bag_len / bag_to_tensors helpers, the TokenCat wrapper, and the superseded quiz_validation_ / generate_c_quizzes_ code paths; add progress/ETA logging to generate_c_quizzes.
author     François Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 07:15:27 +0000 (09:15 +0200)
committer  François Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 07:15:27 +0000 (09:15 +0200)
main.py

diff --git a/main.py b/main.py
index 5f80fb5..05bb108 100755
--- a/main.py
+++ b/main.py
@@ -336,17 +336,6 @@ log_string(f"vocabulary_size {vocabulary_size}")
 
 ######################################################################
 
-
-def bag_len(bag):
-    return sum([x.size(0) for x in bag])
-
-
-def bag_to_tensors(bag):
-    return tuple(torch.cat([x[i] for x in bag], dim=0) for i in range(len(bag[0])))
-
-
-######################################################################
-
 # If we need to move an optimizer to a different device
 
 
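The context comment above mentions moving an optimizer to a different device. The helper itself lies outside this hunk; as a rough illustration only (not the repository's implementation), relocating a torch.optim optimizer usually amounts to moving every tensor held in its state:

    import torch

    def optimizer_to(optimizer, device):
        # Move every tensor in the optimizer state
        # (e.g. Adam's exp_avg / exp_avg_sq buffers) in place.
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(device)
        return optimizer
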
@@ -651,22 +640,6 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
     )
 
 
-######################################################################
-
-
-class TokenCat(nn.Module):
-    def __init__(self, m, n):
-        super().__init__()
-        self.m = m
-        self.n = n
-
-    def forward(self, x):
-        u = torch.cat([x.new_zeros(x.size(0), self.n), x], dim=1)
-        u = self.m(u)
-        u = u[:, self.n :]
-        return u
-
-
 ######################################################################
 
 import attae
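For reference, the removed TokenCat wrapper prepends self.n zero tokens along the sequence dimension, applies the wrapped module m, then strips that prefix, so m sees sequences of length T + n while the caller's length T is preserved. A hypothetical usage, with illustrative identifiers not taken from the repository:

    # inner: any nn.Module mapping (B, T) token ids to same-length outputs
    # wrapped = TokenCat(inner, n=4)   # inner is applied to (B, T + 4)
    # y = wrapped(x)                   # y keeps the caller's original length T
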
@@ -701,186 +674,13 @@ for i in range(args.nb_models):
 ######################################################################
 
 
-def quiz_validation_(
-    models,
-    c_quizzes,
-    local_device,
-    nb_have_to_be_correct,
-    nb_have_to_be_wrong,
-    nb_mistakes_to_be_wrong,
-    nb_hints,
-    nb_runs=1,
-):
-    ######################################################################
-    # If there are too many quizzes, process them batch by batch
-
-    if c_quizzes.size(0) > args.inference_batch_size:
-        record = []
-        for q, nh in zip(
-            c_quizzes.split(args.inference_batch_size),
-            nb_hints.split(args.inference_batch_size),
-        ):
-            record.append(
-                quiz_validation(
-                    models=models,
-                    c_quizzes=q,
-                    local_device=local_device,
-                    nb_have_to_be_correct=nb_have_to_be_correct,
-                    nb_have_to_be_wrong=nb_have_to_be_wrong,
-                    nb_mistakes_to_be_wrong=nb_mistakes_to_be_wrong,
-                    nb_hints=nh,
-                    nb_runs=nb_runs,
-                )
-            )
-
-        r = []
-        for k in range(len(record[0])):
-            r.append(torch.cat([x[k] for x in record], dim=0))
-
-        return tuple(r)
-    ######################################################################
-
-    record_wrong = []
-    nb_correct, nb_wrong = 0, 0
-
-    for i, model in enumerate(models):
-        assert i == model.id  # a bit of paranoia
-        model = copy.deepcopy(model).to(local_device).eval()
-        correct, wrong = True, False
-        for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
-            mask_generate = quiz_machine.make_quiz_mask(
-                quizzes=c_quizzes,
-                quad_order=("A", "f_A", "B", "f_B"),
-                quad_mask=quad,
-            )
-
-            sub_correct, sub_wrong = False, True
-            for _ in range(nb_runs):
-                result = diffuser.generate(
-                    model=model,
-                    x_0=c_quizzes,
-                    mask_generate=mask_generate,
-                    nb_hints=nb_hints,
-                )
-
-                nb_mistakes = (result != c_quizzes).long().sum(dim=1)
-                sub_correct = sub_correct | (nb_mistakes == 0)
-                sub_wrong = sub_wrong & (nb_mistakes >= nb_mistakes_to_be_wrong)
-
-            correct = correct & sub_correct
-            wrong = wrong | sub_wrong
-
-        record_wrong.append(wrong[:, None])
-        nb_correct += correct.long()
-        nb_wrong += wrong.long()
-
-    to_keep = (nb_correct >= nb_have_to_be_correct) & (nb_wrong >= nb_have_to_be_wrong)
-
-    wrong = torch.cat(record_wrong, dim=1)
-
-    return to_keep, nb_correct, nb_wrong, wrong
-
-
-######################################################################
-
-
-def generate_c_quizzes_(models, nb, local_device=main_device):
-    # To be thread-safe we must make copies
-
-    def copy_for_inference(model):
-        return copy.deepcopy(model).to(local_device).eval()
-
-    quad_order = ("A", "f_A", "B", "f_B")
-
-    template = quiz_machine.problem.create_empty_quizzes(
-        nb=args.inference_batch_size, quad_order=quad_order
-    ).to(local_device)
-
-    wanted_nb = nb
-    nb_to_save = 256
-    nb_c_quizzes_per_model = torch.zeros(len(models), device=local_device)
-
-    with torch.autograd.no_grad():
-        record_c_quizzes, record_agreements = [], []
-
-        last_log = -1
-        start_time = time.perf_counter()
-
-        while nb_c_quizzes_per_model.min() < wanted_nb:
-            model = copy_for_inference(models[torch.randint(len(models), (1,)).item()])
-            generator_id = model.id
-
-            mask_generate = quiz_machine.make_quiz_mask(
-                quizzes=template, quad_order=quad_order, quad_mask=(1, 1, 1, 1)
-            )
-
-            c_quizzes = diffuser.generate(model, template, mask_generate)
-
-            to_keep = quiz_machine.problem.trivial(c_quizzes) == False
-            c_quizzes = c_quizzes[to_keep]
-
-            nb_hints = torch.full(
-                (c_quizzes.size(0),), args.nb_hints, device=c_quizzes.device
-            )
-
-            if c_quizzes.size(0) > 0:
-                to_keep, nb_correct, nb_wrong, record_wrong = quiz_validation(
-                    models,
-                    c_quizzes,
-                    local_device,
-                    nb_have_to_be_correct=args.nb_have_to_be_correct,
-                    nb_have_to_be_wrong=args.nb_have_to_be_wrong,
-                    nb_mistakes_to_be_wrong=args.nb_mistakes_to_be_wrong,
-                    nb_hints=nb_hints,
-                    nb_runs=args.nb_runs,
-                )
-
-                # to_keep[...]=True
-
-                q = c_quizzes[to_keep]
-
-                if q.size(0) > 0:
-                    record_c_quizzes.append(q)
-                    a = (record_wrong == False)[to_keep]
-                    record_agreements.append(a)
-                    nb_c_quizzes_per_model += a.long().sum(dim=0)
-
-            duration = time.perf_counter() - start_time
-            nb_generated = nb_c_quizzes_per_model.min().item()
-
-            if last_log < 0 or duration > last_log + 5:
-                last_log = duration
-                if nb_generated > 0:
-                    if nb_generated < wanted_nb:
-                        d = (wanted_nb - nb_generated) * duration / nb_generated
-                        e = (
-                            datetime.datetime.now() + datetime.timedelta(seconds=d)
-                        ).strftime("%a %H:%M")
-                    else:
-                        e = "now!"
-                else:
-                    e = "???"
-
-                log_string(
-                    f"nb_generated {bag_len(record_c_quizzes)} model {generator_id} (finishes {e} -- {int((nb_generated * 3600)/duration)}/h)"
-                )
-
-        duration = time.perf_counter() - start_time
-
-        log_string(f"generate_c_quizz_speed {int(3600 * wanted_nb / duration)}/h")
-
-        c_quizzes = torch.cat(record_c_quizzes, dim=0)
-        agreements = torch.cat(record_agreements, dim=0)
-
-    return c_quizzes.to("cpu"), agreements.to("cpu")
-
-
-######################################################################
-
-
 def generate_c_quizzes(models, nb, local_device=main_device):
     record = []
     nb_validated = 0
+
+    start_time = time.perf_counter()
+    last_log = -1
+
     while nb_validated < nb:
         model = models[torch.randint(len(models), (1,)).item()]
         model = copy.deepcopy(model).to(local_device).eval()
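The two lines above show the pattern used throughout generation: pick one model at random and run inference on a deep copy moved to the local device, so concurrent worker threads never touch the shared models (the same idea as the removed copy_for_inference helper). Purely as a sketch, with a hypothetical models list:

    import copy, torch

    def copy_for_inference(model, local_device):
        # Detached copy: device placement and eval() mode stay local
        # to the calling thread.
        return copy.deepcopy(model).to(local_device).eval()

    # model = copy_for_inference(models[torch.randint(len(models), (1,)).item()],
    #                            local_device)
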
@@ -911,6 +711,33 @@ def generate_c_quizzes(models, nb, local_device=main_device):
 
         log_string(f"generate_c_quizzes {nb_validated}")
 
+        #####################
+
+        duration = time.perf_counter() - start_time
+
+        if last_log < 0 or duration > last_log + 10:
+            last_log = duration
+            if nb_validated > 0:
+                if nb_validated < wanted_nb:
+                    d = (wanted_nb - nb_validated) * duration / nb_validated
+                    e = (
+                        datetime.datetime.now() + datetime.timedelta(seconds=d)
+                    ).strftime("%a %H:%M")
+                else:
+                    e = "now!"
+            else:
+                e = "???"
+
+            log_string(
+                f"nb_validated {nb_validated} model {generator_id} (finishes {e} -- {int((nb_validated * 3600)/duration)}/h)"
+            )
+
+        #####################
+
+    duration = time.perf_counter() - start_time
+
+    log_string(f"generate_c_quizz_speed {int(3600 * wanted_nb / duration)}/h")
+
     return torch.cat(record)
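
The logging added above estimates a finish time by linear extrapolation: after producing nb_validated quizzes in duration seconds, the remaining (wanted_nb - nb_validated) quizzes are assumed to take (wanted_nb - nb_validated) * duration / nb_validated more seconds, and the reported rate is nb_validated * 3600 / duration per hour. A self-contained sketch of that estimate, with made-up numbers, for illustration only:

    import datetime

    def eta(nb_validated, wanted_nb, duration):
        # Linear extrapolation of the remaining time from the rate so far.
        if nb_validated == 0:
            return "???"
        if nb_validated >= wanted_nb:
            return "now!"
        d = (wanted_nb - nb_validated) * duration / nb_validated
        t = datetime.datetime.now() + datetime.timedelta(seconds=d)
        return t.strftime("%a %H:%M")

    # e.g. eta(128, 1000, 600.0): 128 quizzes in 10 min is about 768/h,
    # so the remaining 872 should take roughly 68 more minutes.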