Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 25 Jun 2024 16:11:35 +0000 (18:11 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 25 Jun 2024 16:11:35 +0000 (18:11 +0200)
quizz_machine.py
sky.py

index a3da365..28b94d1 100755 (executable)
@@ -70,15 +70,6 @@ import sky
 
 
 class QuizzMachine:
-    def save_image(self, input, result_dir, filename, logger):
-        img = self.sky.seq2img(input.to("cpu"))
-        image_name = os.path.join(result_dir, filename)
-        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
-        logger(f"wrote {image_name}")
-
-    def save_quizzes(self, input, result_dir, filename_prefix, logger):
-        self.save_image(input, result_dir, filename_prefix + ".png", logger)
-
     def make_ar_mask(self, input):
         b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2
         return b.long()[None, :].expand_as(input)
@@ -94,12 +85,12 @@ class QuizzMachine:
     ):
         super().__init__()
 
-        self.sky = sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2)
+        self.problem = sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2)
         self.batch_size = batch_size
         self.device = device
 
-        self.train_w_quizzes = self.sky.generate_seq(nb_train_samples).to(device)
-        self.test_w_quizzes = self.sky.generate_seq(nb_test_samples).to(device)
+        self.train_w_quizzes = self.problem.generate_seq(nb_train_samples).to(device)
+        self.test_w_quizzes = self.problem.generate_seq(nb_test_samples).to(device)
 
         self.nb_codes = max(self.train_w_quizzes.max(), self.test_w_quizzes.max()) + 1
 
@@ -107,7 +98,7 @@ class QuizzMachine:
         self.test_c_quizzes = []
 
         if result_dir is not None:
-            self.save_quizzes(
+            self.problem.save_quizzes(
                 self.train_w_quizzes[:72], result_dir, f"culture_w_quizzes", logger
             )
 
@@ -215,7 +206,7 @@ class QuizzMachine:
             device=self.device,
         )
 
-        self.save_quizzes(
+        self.problem.save_quizzes(
             result[:72],
             result_dir,
             f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
@@ -228,7 +219,7 @@ class QuizzMachine:
         input = self.train_w_quizzes if for_train else self.test_w_quizzes
         nb = min(nb, input.size(0))
         input[:-nb] = input[nb:].clone()
-        input[-nb:] = self.sky.generate_seq(nb).to(self.device)
+        input[-nb:] = self.problem.generate_seq(nb).to(self.device)
 
     def store_c_quizzes(self, new_c_quizzes, for_train=True):
         if for_train:
@@ -298,11 +289,13 @@ class QuizzMachine:
         ###############################################################
         # Create the reverse quizzes
 
+        token_forward, token_backward = self.problem.direction_tokens()
+
         l = (c_quizzes.size(1) - 1) // 2
         direction = c_quizzes[:, l : l + 1]
-        direction = self.sky.token_forward * (
-            direction == self.sky.token_backward
-        ) + self.sky.token_backward * (direction == self.sky.token_forward)
+        direction = self.problem.token_forward * (
+            direction == self.problem.token_backward
+        ) + self.problem.token_backward * (direction == self.problem.token_forward)
         reverse_c_quizzes = torch.cat(
             [c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1
         )
diff --git a/sky.py b/sky.py
index a90e37d..cb25ea0 100755 (executable)
--- a/sky.py
+++ b/sky.py
@@ -5,7 +5,7 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
-import math, sys, tqdm
+import math, sys, tqdm, os
 
 import torch, torchvision
 
@@ -15,6 +15,17 @@ from torch.nn import functional as F
 ######################################################################
 
 
+class Problem:
+    def generate_seq(self, nb_train_samples):
+        pass
+
+    def save_quizzes(self, input, result_dir, filename_prefix, logger):
+        pass
+
+    def direction_tokens(self):
+        pass
+
+
 class Sky:
     colors = torch.tensor(
         [
@@ -48,6 +59,9 @@ class Sky:
         self.nb_birds = nb_birds
         self.nb_iterations = nb_iterations
 
+    def direction_tokens(self):
+        return self.token_forward, self.token_backward
+
     def generate_seq(self, nb, return_iterations=False):
         pairs = []
         kept_iterations = []
@@ -338,6 +352,15 @@ class Sky:
             result.append("".join([self.token2char[v] for v in s]))
         return result
 
+    def save_image(self, input, result_dir, filename, logger):
+        img = self.seq2img(input.to("cpu"))
+        image_name = os.path.join(result_dir, filename)
+        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
+        logger(f"wrote {image_name}")
+
+    def save_quizzes(self, input, result_dir, filename_prefix, logger):
+        self.save_image(input, result_dir, filename_prefix + ".png", logger)
+
 
 ######################################################################