Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 24 Jul 2024 09:31:48 +0000 (11:31 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 24 Jul 2024 09:31:48 +0000 (11:31 +0200)
grids.py
problem.py
quiz_machine.py

index e64cb33..131f85c 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -118,80 +118,49 @@ class Grids(problem.Problem):
         ("gray", [128, 128, 128]),
     ]
 
-    def make_ar_mask(self, quizzes, shape="fwd_3_bck_123"):
+    def check_structure(self, quizzes, struct):
         S = self.height * self.width
 
-        assert (
-            (
-                (quizzes[:, 0] == self.token_forward)
-                | (quizzes[:, 0] == self.token_backward)
-            )
-            & (quizzes[:, 0] == quizzes[:, 1 * (S + 1)])
-            & (quizzes[:, 0] == quizzes[:, 2 * (S + 1)])
-            & (quizzes[:, 0] == quizzes[:, 3 * (S + 1)])
+        return (
+            (quizzes[:, 0 * (S + 1)] == self.l2tok(struct[0]))
+            & (quizzes[:, 1 * (S + 1)] == self.l2tok(struct[1]))
+            & (quizzes[:, 2 * (S + 1)] == self.l2tok(struct[2]))
+            & (quizzes[:, 3 * (S + 1)] == self.l2tok(struct[3]))
         ).all()
 
-        T = torch.arange(quizzes.size(1), device=quizzes.device)
-
-        if shape == "fwd_3_bck_123":
-            forward_mask = ((T % (S + 1) != 0) & (T >= 3 * (S + 1))).long()
-            backward_mask = ((T % (S + 1) != 0) & (T >= 1 * (S + 1))).long()
-        elif shape == "fwd_012_bck_0":
-            forward_mask = ((T % (S + 1) != 0) & (T < 3 * (S + 1))).long()
-            backward_mask = ((T % (S + 1) != 0) & (T < 1 * (S + 1))).long()
-        elif shape == "fwd_3_bck_3":
-            forward_mask = ((T % (S + 1) != 0) & (T >= 3 * (S + 1))).long()
-            backward_mask = ((T % (S + 1) != 0) & (T >= 3 * (S + 1))).long()
-        else:
-            raise ValueError(shape)
-
-        is_forward = (quizzes[:, 0] == self.token_forward).long()
+    def make_ar_mask(self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)):
+        assert check_structure(quizzes, struct)
 
-        return (
-            is_forward[:, None] * forward_mask[None, :]
-            + (1 - is_forward)[:, None] * backward_mask[None, :]
-        )
+        ar_mask = quizzes.new_zeros(quizzes.size())
 
-    def p_a_flip(self, quizzes, pairwise_flip=False):
-        S = self.height * self.width
+        a = ar_mask.reshape(-1, 4, -1)[:, :, 1:]
+        a[:, 0, :] = mask[0]
+        a[:, 1, :] = mask[1]
+        a[:, 2, :] = mask[2]
+        a[:, 3, :] = mask[3]
 
-        assert (
-            (
-                (quizzes[:, 0] == self.token_forward)
-                | (quizzes[:, 0] == self.token_backward)
-            )
-            & (quizzes[:, 0] == quizzes[:, 1 * (S + 1)])
-            & (quizzes[:, 0] == quizzes[:, 2 * (S + 1)])
-            & (quizzes[:, 0] == quizzes[:, 3 * (S + 1)])
-        ).all()
+        return ar_mask
 
-        if pairwise_flip:
-            flipped = torch.cat(
-                [
-                    quizzes[:, 1 * (S + 1) : 1 * (S + 1) + S + 1],
-                    quizzes[:, 0 * (S + 1) : 0 * (S + 1) + S + 1],
-                    quizzes[:, 3 * (S + 1) : 3 * (S + 1) + S + 1],
-                    quizzes[:, 2 * (S + 1) : 2 * (S + 1) + S + 1],
-                ],
-                dim=1,
-            )
-        else:
-            flipped_from_forward = torch.cat(
-                [quizzes[:, 3 * (S + 1) :], quizzes[:, : 3 * (S + 1)]],
-                dim=1,
-            )
-            flipped_from_forward[:, torch.arange(4) * (S + 1)] = self.token_backward
+    def reconfigure(
+        self,
+        quizzes,
+        struct_from=("A", "f_A", "B", "f_B"),
+        struct_to=("f_B", "A", "f_A", "B"),
+    ):
+        assert check_structure(quizzes, struct_from)
 
-            flipped_from_backward = torch.cat(
-                [quizzes[:, S + 1 :], quizzes[:, : S + 1]], dim=1
-            )
-            flipped_from_backward[:, torch.arange(4) * (S + 1)] = self.token_forward
+        sf = dict((l, n) for n, l in enumerate(struct_from))
 
-            m = (quizzes[:, 0] == self.token_forward).long()[:, None]
+        result = quizzes.new(quizzes.size())
+        q = quizzes.reshape(-1, 4, 4 * (S + 1))
+        r = reshape.reshape(-1, 4, 4 * (S + 1))
 
-            flipped = m * flipped_from_forward + (1 - m) * flipped_from_backward
+        r[:, 0, :] = q[:, sf[struct_to[0]]]
+        r[:, 1, :] = q[:, sf[struct_to[1]]]
+        r[:, 2, :] = q[:, sf[struct_to[2]]]
+        r[:, 3, :] = q[:, sf[struct_to[3]]]
 
-        return flipped
+        return result
 
     def __init__(
         self,
@@ -201,8 +170,20 @@ class Grids(problem.Problem):
         tasks=None,
     ):
         self.colors = torch.tensor([c for _, c in self.named_colors])
-        self.token_forward = len(self.colors)
-        self.token_backward = self.token_forward + 1
+
+        self.token_A = len(self.colors)
+        self.token_f_A = self.token_A + 1
+        self.token_B = self.token_f_A + 1
+        self.token_f_B = self.token_B + 1
+        self.l2tok = {
+            "A": self.token_A,
+            "f_A": self.token_f_A,
+            "B": self.token_B,
+            "f_B": self.token_f_B,
+        }
+
+        self.nb_token_values = self.token_f_B + 1
+
         self.height = 10
         self.width = 10
         self.cache_rec_coo = {}
@@ -237,8 +218,7 @@ class Grids(problem.Problem):
 
     ######################################################################
 
-    def frame2img(self, x, scale=15):
-        x = x.reshape(x.size(0), self.height, self.width)
+    def grid2img(self, x, scale=15):
         m = torch.logical_and(x >= 0, x < len(self.colors)).long()
         y = self.colors[x * m].permute(0, 3, 1, 2)
         s = y.shape
@@ -247,154 +227,95 @@ class Grids(problem.Problem):
 
         y[:, :, :, torch.arange(0, y.size(3), scale)] = 0
         y[:, :, torch.arange(0, y.size(2), scale), :] = 0
-        y = y[:, :, 1:, 1:]
 
         for n in range(m.size(0)):
             for i in range(m.size(1)):
                 for j in range(m.size(2)):
-                    if x[n, i, j] == self.token_forward:
-                        for k in range(2, scale - 2):
-                            y[
-                                n,
-                                :,
-                                i * scale + k,
-                                j * scale + scale - 5 - abs(k - scale // 2),
-                            ] = 0
-
-                    elif x[n, i, j] == self.token_backward:
-                        for k in range(2, scale - 2):
-                            y[
-                                n, :, i * scale + k, j * scale + 3 + abs(k - scale // 2)
-                            ] = 0
-                            # y[n, :, i * scale + k, j * scale + k - l] = 0
-                            # y[
-                            # n, :, i * scale + scale - 1 - k, j * scale + k - l
-                            # ] = 0
+                    if m[n, i, j] == 0:
+                        for k in range(3, scale - 2):
+                            y[n, :, i * scale + k, j * scale + k] = 0
+                            y[n, :, i * scale + k, j * scale + scale - k] = 0
+
+        y = y[:, :, 1:, 1:]
 
         return y
 
-    def save_image(
+    def add_frame(self, img, colors, thickness):
+        result = img.new(
+            img.size(0),
+            img.size(1),
+            img.size(2) + 2 * thickness,
+            img.size(3) + 2 * thickness,
+        )
+
+        result[...] = colors[:, :, None, None]
+        result[:, :, thickness:-thickness, thickness:-thickness] = img
+
+        return result
+
+    def save_quizzes_as_image(
         self,
         result_dir,
         filename,
-        prompts,
-        answers,
-        predicted_prompts=None,
-        predicted_answers=None,
+        quizzes,
+        predicted_parts=None,
+        correct_parts=None,
         nrow=4,
         margin=8,
     ):
         S = self.height * self.width
-        As = prompts[:, 0 * (S + 1) + 1 : 0 * (S + 1) + S + 1].view(
-            -1, self.height, self.width
-        )
-        f_As = prompts[:, 1 * (S + 1) + 1 : 1 * (S + 1) + S + 1].view(
-            -1, self.height, self.width
-        )
-        Bs = prompts[:, 2 * (S + 1) + 1 : 2 * (S + 1) + S + 1].view(
-            -1, self.height, self.width
+
+        A, f_A, B, f_B = (
+            quizzes.reshape(-1, 4, S + 1)[:, :, 1:]
+            .reshape(-1, 4, self.height, self.width)
+            .permute(1, 0, 2, 3)
         )
-        prompts = torch.cat([As, f_As, Bs], dim=2)
-        answers = answers[:, 1 : S + 1].reshape(
-            answers.size(0), self.height, self.width
+
+        black, white, gray, green, red = torch.tensor(
+            [[0, 0, 0], [255, 255, 255], [200, 200, 200], [0, 255, 0], [255, 0, 0]],
+            device=quizzes.device,
         )
 
-        if predicted_prompts is None:
-            predicted_prompts = 255
+        img_A = self.add_frame(self.grid2img(A), black[None, :], thickness=1)
+        img_f_A = self.add_frame(self.grid2img(f_A), black[None, :], thickness=1)
+        img_B = self.add_frame(self.grid2img(B), black[None, :], thickness=1)
+        img_f_B = self.add_frame(self.grid2img(f_B), black[None, :], thickness=1)
 
-        if predicted_answers is None:
-            predicted_answers = 255
+        # predicted_parts Nx4
+        # correct_parts Nx4
 
-        def add_frame(x, c, margin, bottom=False):
-            if bottom:
-                h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
-            else:
-                h, w, di, dj = (
-                    x.size(2) + 2 * margin,
-                    x.size(3) + 2 * margin,
-                    margin,
-                    margin,
+        if predicted_parts is None:
+            colors = white[None, None, :].expand(-1, 4, -1)
+        else:
+            if correct_parts is None:
+                colors = (
+                    predicted_parts[:, :, None] * gray[None, None, :]
+                    + (1 - predicted_parts[:, :, None]) * white[None, None, :]
                 )
-
-            y = x.new_full((x.size(0), x.size(1), h, w), 0)
-
-            if type(c) is int:
-                y[...] = c
             else:
-                c = c.long()[:, None]
-                c = (
-                    (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long()))
-                    * torch.tensor([64, 64, 64])
-                    + (c == 1).long() * torch.tensor([0, 255, 0])
-                    + (c == 0).long() * torch.tensor([255, 255, 255])
-                    + (c == -1).long() * torch.tensor([255, 0, 0])
-                )
-                y[...] = c[:, :, None, None]
-
-            y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
-
-            return y
-
-        img_prompts = torch.cat(
-            [
-                add_frame(
-                    add_frame(self.frame2img(x), c=0, margin=1),
-                    c=predicted_prompts,
-                    margin=margin,
+                colors = (
+                    predicted_parts[:, :, None]
+                    * (
+                        correct_parts[:, :, None] * green[None, None, :]
+                        + (1 - correct_parts[:, :, None]) * red[None, None, :]
+                    )
+                    + (1 - predicted_parts[:, :, None]) * white[None, None, :]
                 )
-                for x in prompts.to("cpu").split(split_size=self.width, dim=2)
-            ],
-            dim=3,
-        )
-
-        h = img_prompts.size(2)
-        img_answers = add_frame(
-            add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1),
-            c=predicted_answers,
-            margin=margin,
-        )
-
-        separator_size = 2 * margin
-
-        separator = img_prompts.new_full(
-            (
-                img_prompts.size(0),
-                img_prompts.size(1),
-                img_prompts.size(2),
-                separator_size,
-            ),
-            255,
-        )
-
-        marker = img_prompts.new_full(
-            (
-                img_prompts.size(0),
-                img_prompts.size(1),
-                img_prompts.size(2),
-                separator_size,
-            ),
-            255,
-        )
 
-        # marker[:, :, 0] = 0
-        # marker[:, :, h - 1] = 0
+        img_A = self.add_frame(img_A, colors[:, 0], thickness=6)
+        img_f_A = self.add_frame(img_f_A, colors[:, 1], thickness=6)
+        img_B = self.add_frame(img_B, colors[:, 2], thickness=6)
+        img_f_B = self.add_frame(img_f_B, colors[:, 3], thickness=6)
 
-        for k in range(1, 2 * separator_size - 8):
-            i = k - (separator_size - 4)
-            j = separator_size - 5 - abs(i)
-            marker[:, :, h // 2 - 1 + i, 2 + j] = 0
-            marker[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
+        img_A = self.add_frame(img_A, white[None, :], thickness=2)
+        img_f_A = self.add_frame(img_f_A, white[None, :], thickness=2)
+        img_B = self.add_frame(img_B, white[None, :], thickness=2)
+        img_f_B = self.add_frame(img_f_B, white[None, :], thickness=2)
 
-        img = torch.cat(
-            [
-                img_prompts,
-                marker,
-                img_answers,
-            ],
-            dim=3,
-        )
+        img = torch.cat([img_A, img_f_A, img_B, img_f_B], dim=3)
 
         image_name = os.path.join(result_dir, filename)
+
         torchvision.utils.save_image(
             img.float() / 255.0,
             image_name,
@@ -405,9 +326,6 @@ class Grids(problem.Problem):
 
     ######################################################################
 
-    def nb_token_values(self):
-        return len(self.colors) + 2
-
     # @torch.compile
     def rec_coo(
         self,
@@ -1444,70 +1362,36 @@ class Grids(problem.Problem):
         f_Bs = answers[:, 1:]
         return (Bs == f_Bs).long().min(dim=-1).values > 0
 
-    def generate_prompts_and_answers_(self, nb, tasks=None, progress_bar=False):
+    def generate_w_quizzes_(self, nb, tasks=None, progress_bar=False):
         if tasks is None:
             tasks = self.all_tasks
 
         S = self.height * self.width
-        prompts = torch.full((nb, 3 * S + 3), self.token_forward)
-        answers = torch.full((nb, S + 1), self.token_forward)
-
-        bunch = zip(prompts, answers)
+        quizzes = torch.empty(nb, 4 * (S + 1), dtype=torch.int64)
 
         if progress_bar:
-            bunch = tqdm.tqdm(
-                bunch,
+            quizzes = tqdm.tqdm(
+                quizzes,
                 dynamic_ncols=True,
-                desc="world generation",
+                desc="world quizzes generation",
                 total=prompts.size(0),
             )
 
-        for prompt, answer in bunch:
-            A = prompt[0 * (S + 1) + 1 : 0 * (S + 1) + 1 + S].view(
-                self.height, self.width
-            )
-            f_A = prompt[1 * (S + 1) + 1 : 1 * (S + 1) + 1 + S].view(
-                self.height, self.width
-            )
-            B = prompt[2 * (S + 1) + 1 : 2 * (S + 1) + S + 1].view(
-                self.height, self.width
-            )
-            f_B = answer[1 : S + 1].view(self.height, self.width)
+        for quiz in quizzes:
+            q = quiz.reshape(4, S + 1)[:, 1:].reshape(4, self.height, self.width)
+            q[...] = 0
+            A, f_A, B, f_B = q
             task = tasks[torch.randint(len(tasks), (1,)).item()]
-            A[...] = 0
-            f_A[...] = 0
-            B[...] = 0
-            f_B[...] = 0
             task(A, f_A, B, f_B)
 
-        return prompts.flatten(1), answers.flatten(1)
-
-    def save_quiz_illustrations(
-        self,
-        result_dir,
-        filename_prefix,
-        prompts,
-        answers,
-        predicted_prompts=None,
-        predicted_answers=None,
-        nrow=4,
-    ):
-        self.save_image(
-            result_dir,
-            filename_prefix + ".png",
-            prompts,
-            answers,
-            predicted_prompts,
-            predicted_answers,
-            nrow,
-        )
+        return quizzes
 
     def save_some_examples(self, result_dir):
         nb, nrow = 128, 4
         for t in self.all_tasks:
             print(t.__name__)
-            prompts, answers = self.generate_prompts_and_answers_(nb, tasks=[t])
-            self.save_quiz_illustrations(
+            prompts, answers = self.generate_w_quizzes_(nb, tasks=[t])
+            self.save_quizzes_as_image(
                 result_dir, t.__name__, prompts[:nb], answers[:nb], nrow=nrow
             )
 
@@ -1526,7 +1410,7 @@ if __name__ == "__main__":
     # )
     #    time.sleep(10)
     # start_time = time.perf_counter()
-    # prompts, answers = grids.generate_prompts_and_answers(nb)
+    # prompts, answers = grids.generate_w_quizzes(nb)
     # delay = time.perf_counter() - start_time
     # print(f"{prompts.size(0)/delay:02f} seq/s")
     # exit(0)
@@ -1536,13 +1420,19 @@ if __name__ == "__main__":
     # nb, nrow = 8, 2
 
     # for t in grids.all_tasks:
-    for t in [grids.task_reconfigure]:
+    for t in [grids.task_replace_color]:
         # for t in [grids.task_symbols]:
         print(t.__name__)
-        prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
-        # prompts[...] = torch.randint(grids.nb_token_values(), prompts.size())
-        grids.save_quiz_illustrations(
-            "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow
+        quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
+        predicted_parts = quizzes.new_zeros(quizzes.size(0), 4)
+        predicted_parts[:, 3] = 1
+        correct_parts = torch.randint(2, (quizzes.size(0), 4), device=quizzes.device)
+        grids.save_quizzes_as_image(
+            "/tmp",
+            t.__name__ + ".png",
+            quizzes,
+            predicted_parts=predicted_parts,
+            correct_parts=correct_parts,
         )
 
     exit(0)
@@ -1552,7 +1442,7 @@ if __name__ == "__main__":
     for t in grids.all_tasks:
         # for t in [grids.task_compute]:
         start_time = time.perf_counter()
-        prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
+        prompts, answers = grids.generate_w_quizzes_(nb, tasks=[t])
         delay = time.perf_counter() - start_time
         print(f"{t.__name__} {prompts.size(0)/delay:02f} seq/s")
 
index 05f3b20..61e4834 100755 (executable)
@@ -32,7 +32,7 @@ class Problem:
         pass
 
     # The one to implement, returns two tensors nb x D and nb x D'
-    def generate_prompts_and_answers_(self, nb):
+    def generate_w_quizzes_(self, nb):
         pass
 
     # save a file to vizualize quizzes, you can save a txt or png file
@@ -49,13 +49,13 @@ class Problem:
 
     def fill_cache(self):
         while True:
-            prompts, answers = self.generate_prompts_and_answers_(self.chunk_size)
+            prompts, answers = self.generate_w_quizzes_(self.chunk_size)
 
             self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True)
 
-    def generate_prompts_and_answers(self, nb):
+    def generate_w_quizzes(self, nb):
         if self.queue is None:
-            return self.generate_prompts_and_answers_(nb)
+            return self.generate_w_quizzes_(nb)
 
         if self.rest is not None:
             prompts, answers = rest
index e70b903..d62ba3b 100755 (executable)
@@ -599,136 +599,3 @@ class QuizMachine:
         return c_quizzes.to("cpu")
 
     ######################################################################
-
-    def generate_c_quizzes_mixing(
-        self,
-        nb,
-        model_for_generation,
-        p2a_only=False,
-        temperature_hot=1.0,
-        temperature_cold=1.0,
-    ):
-        c_quizzes = torch.empty(
-            nb,
-            self.prompt_len + self.answer_len,
-            device=self.device,
-            dtype=torch.int64,
-        )
-
-        c_quizzes_1 = torch.empty(
-            nb,
-            self.prompt_len + self.answer_len,
-            device=self.device,
-            dtype=torch.int64,
-        )
-
-        c_quizzes_2 = torch.empty(
-            nb,
-            self.prompt_len + self.answer_len,
-            device=self.device,
-            dtype=torch.int64,
-        )
-
-        seq_logproba = torch.zeros(nb, device=self.device)
-
-        lt_noisy = lambda s, logits: logits / temperature_hot
-        lt_clean = lambda s, logits: logits / temperature_cold
-
-        ######################################################################
-
-        c_quizzes_1[...] = self.problem.token_backward
-        ar_mask = self.problem.make_ar_mask(c_quizzes_1, shape="fwd_012_bck_0")
-
-        masked_inplace_autoregression(
-            model=model_for_generation,
-            batch_size=self.batch_size,
-            input=c_quizzes_1,
-            ar_mask=ar_mask,
-            seq_logproba=seq_logproba,
-            logit_transformer=lt_noisy,
-            deterministic_synthesis=False,
-            device=self.device,
-        )
-
-        self.save_quiz_illustrations("/tmp", f"c_quizzes_1", c_quizzes_1)
-
-        c_quizzes_2[...] = self.problem.token_backward
-
-        masked_inplace_autoregression(
-            model=model_for_generation,
-            batch_size=self.batch_size,
-            input=c_quizzes_2,
-            ar_mask=ar_mask,
-            seq_logproba=seq_logproba,
-            logit_transformer=lt_noisy,
-            deterministic_synthesis=False,
-            device=self.device,
-        )
-
-        self.save_quiz_illustrations("/tmp", f"c_quizzes_2", c_quizzes_2)
-
-        h = len(model_for_generation.trunk) // 2
-
-        with torch.autograd.no_grad():
-            t = model_for_generation.training
-            model_for_generation.eval()
-
-            bs1 = model_for_generation.partial_forward(
-                mygpt.BracketedSequence(c_quizzes_1), end_layer=h
-            )
-            bs2 = model_for_generation.partial_forward(
-                mygpt.BracketedSequence(c_quizzes_2), end_layer=h
-            )
-
-            alpha = 0.5
-
-            output = model_for_generation.partial_forward(
-                mygpt.BracketedSequence(alpha * bs1.x + (1 - alpha) * bs2.x),
-                start_layer=h,
-            ).x
-
-            dist = torch.distributions.categorical.Categorical(logits=output)
-            c_quizzes[...] = dist.sample()
-
-            c_quizzes[...] = (
-                ar_mask * c_quizzes + (1 - ar_mask) * self.problem.token_backward
-            )
-
-            model_for_generation.train(t)
-
-        self.save_quiz_illustrations("/tmp", f"c_quizzes", c_quizzes)
-
-        ######################################################################
-
-        masked_inplace_autoregression(
-            model=model_for_generation,
-            batch_size=self.batch_size,
-            input=c_quizzes,
-            ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
-            seq_logproba=seq_logproba,
-            logit_transformer=lt_clean,
-            deterministic_synthesis=False,
-            device=self.device,
-        )
-
-        self.save_quiz_illustrations("/tmp", f"c_quizzes_A", c_quizzes)
-
-        c_quizzes = self.problem.p_a_flip(c_quizzes)
-
-        masked_inplace_autoregression(
-            model=model_for_generation,
-            batch_size=self.batch_size,
-            input=c_quizzes,
-            ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
-            seq_logproba=seq_logproba,
-            logit_transformer=lt_clean,
-            deterministic_synthesis=False,
-            device=self.device,
-        )
-
-        self.save_quiz_illustrations("/tmp", f"c_quizzes_B", c_quizzes)
-
-        print("DONE")
-        exit(0)
-
-        return c_quizzes.to("cpu")