Update.
[culture.git] / quizz_machine.py
index daa8a54..d63855c 100755 (executable)
@@ -67,40 +67,14 @@ def masked_inplace_autoregression(
 ######################################################################
 
 
 ######################################################################
 
 
-class Task:
-    def batches(self, split="train", nb_to_use=-1, desc=None):
-        pass
-
-    def vocabulary_size(self):
-        pass
-
-    def produce_results(
-        self, n_epoch, model, result_dir, logger, deterministic_synthesis
-    ):
-        pass
-
-
-######################################################################
-
-import sky
-
-
-class QuizzMachine(Task):
-    def save_image(self, input, result_dir, filename, logger):
-        img = sky.seq2img(input.to("cpu"), self.height, self.width)
-        image_name = os.path.join(result_dir, filename)
-        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
-        logger(f"wrote {image_name}")
-
-    def save_quizzes(self, input, result_dir, filename_prefix, logger):
-        self.save_image(input, result_dir, filename_prefix + ".png", logger)
-
+class QuizzMachine:
     def make_ar_mask(self, input):
         b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2
         return b.long()[None, :].expand_as(input)
 
     def __init__(
         self,
     def make_ar_mask(self, input):
         b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2
         return b.long()[None, :].expand_as(input)
 
     def __init__(
         self,
+        problem,
         nb_train_samples,
         nb_test_samples,
         batch_size,
         nb_train_samples,
         nb_test_samples,
         batch_size,
@@ -110,18 +84,12 @@ class QuizzMachine(Task):
     ):
         super().__init__()
 
     ):
         super().__init__()
 
+        self.problem = problem
         self.batch_size = batch_size
         self.device = device
         self.batch_size = batch_size
         self.device = device
-        self.height = 6
-        self.width = 8
 
 
-        self.train_w_quizzes = sky.generate_seq(
-            nb_train_samples, height=self.height, width=self.width
-        ).to(device)
-
-        self.test_w_quizzes = sky.generate_seq(
-            nb_test_samples, height=self.height, width=self.width
-        ).to(device)
+        self.train_w_quizzes = self.problem.generate_seq(nb_train_samples).to(device)
+        self.test_w_quizzes = self.problem.generate_seq(nb_test_samples).to(device)
 
         self.nb_codes = max(self.train_w_quizzes.max(), self.test_w_quizzes.max()) + 1
 
 
         self.nb_codes = max(self.train_w_quizzes.max(), self.test_w_quizzes.max()) + 1
 
@@ -129,7 +97,7 @@ class QuizzMachine(Task):
         self.test_c_quizzes = []
 
         if result_dir is not None:
         self.test_c_quizzes = []
 
         if result_dir is not None:
-            self.save_quizzes(
+            self.problem.save_quizzes(
                 self.train_w_quizzes[:72], result_dir, f"culture_w_quizzes", logger
             )
 
                 self.train_w_quizzes[:72], result_dir, f"culture_w_quizzes", logger
             )
 
@@ -237,7 +205,7 @@ class QuizzMachine(Task):
             device=self.device,
         )
 
             device=self.device,
         )
 
-        self.save_quizzes(
+        self.problem.save_quizzes(
             result[:72],
             result_dir,
             f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
             result[:72],
             result_dir,
             f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
@@ -250,9 +218,7 @@ class QuizzMachine(Task):
         input = self.train_w_quizzes if for_train else self.test_w_quizzes
         nb = min(nb, input.size(0))
         input[:-nb] = input[nb:].clone()
         input = self.train_w_quizzes if for_train else self.test_w_quizzes
         nb = min(nb, input.size(0))
         input[:-nb] = input[nb:].clone()
-        input[-nb:] = sky.generate_seq(nb, height=self.height, width=self.width).to(
-            self.device
-        )
+        input[-nb:] = self.problem.generate_seq(nb).to(self.device)
 
     def store_c_quizzes(self, new_c_quizzes, for_train=True):
         if for_train:
 
     def store_c_quizzes(self, new_c_quizzes, for_train=True):
         if for_train:
@@ -274,7 +240,7 @@ class QuizzMachine(Task):
         # Generate quizzes with model
 
         c_quizzes = torch.empty(
         # Generate quizzes with model
 
         c_quizzes = torch.empty(
-            nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
+            nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
         )
 
         ar_mask = torch.full(c_quizzes.size(), 1, device=self.device)
         )
 
         ar_mask = torch.full(c_quizzes.size(), 1, device=self.device)
@@ -300,17 +266,15 @@ class QuizzMachine(Task):
 
             ave_seq_logproba = seq_logproba.mean()
 
 
             ave_seq_logproba = seq_logproba.mean()
 
-            logger(f"{ave_seq_logproba=} {min_ave_seq_logproba=}")
-
             if min_ave_seq_logproba is None:
                 break
 
             # Oh man that's ugly
             if min_ave_seq_logproba is None:
                 break
 
             # Oh man that's ugly
-            if ave_seq_logproba < min_ave_seq_logproba * 1.1:
+            if ave_seq_logproba < min_ave_seq_logproba:
                 if d_temperature > 0:
                     d_temperature *= -1 / 3
                 temperature += d_temperature
                 if d_temperature > 0:
                     d_temperature *= -1 / 3
                 temperature += d_temperature
-            elif ave_seq_logproba > min_ave_seq_logproba:
+            elif ave_seq_logproba > min_ave_seq_logproba * 0.99:
                 if d_temperature < 0:
                     d_temperature *= -1 / 3
                 temperature += d_temperature
                 if d_temperature < 0:
                     d_temperature *= -1 / 3
                 temperature += d_temperature
@@ -322,11 +286,13 @@ class QuizzMachine(Task):
         ###############################################################
         # Create the reverse quizzes
 
         ###############################################################
         # Create the reverse quizzes
 
-        l = self.height * self.width
+        token_forward, token_backward = self.problem.direction_tokens()
+
+        l = (c_quizzes.size(1) - 1) // 2
         direction = c_quizzes[:, l : l + 1]
         direction = c_quizzes[:, l : l + 1]
-        direction = sky.token_forward * (
-            direction == sky.token_backward
-        ) + sky.token_backward * (direction == sky.token_forward)
+        direction = self.problem.token_forward * (
+            direction == self.problem.token_backward
+        ) + self.problem.token_backward * (direction == self.problem.token_forward)
         reverse_c_quizzes = torch.cat(
             [c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1
         )
         reverse_c_quizzes = torch.cat(
             [c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1
         )