Update.
[culture.git] / sky.py
diff --git a/sky.py b/sky.py
index cb25ea0..ac6cbdc 100755 (executable)
--- a/sky.py
+++ b/sky.py
@@ -14,19 +14,10 @@ from torch.nn import functional as F
 
 ######################################################################
 
+import problem
 
-class Problem:
-    def generate_seq(self, nb_train_samples):
-        pass
 
-    def save_quizzes(self, input, result_dir, filename_prefix, logger):
-        pass
-
-    def direction_tokens(self):
-        pass
-
-
-class Sky:
+class Sky(problem.Problem):
     colors = torch.tensor(
         [
             [255, 255, 255],
@@ -267,31 +258,34 @@ class Sky:
 
         return torch.cat(result, dim=0)
 
-    def frame2img(self, x, upscale=15):
+    def frame2img(self, x, scale=15):
         x = x.reshape(-1, self.height, self.width)
         m = torch.logical_and(
             x >= 0, x < self.first_bird_token + self.nb_bird_tokens
         ).long()
         x = self.colors[x * m].permute(0, 3, 1, 2)
         s = x.shape
-        x = x[:, :, :, None, :, None].expand(-1, -1, -1, upscale, -1, upscale)
-        x = x.reshape(s[0], s[1], s[2] * upscale, s[3] * upscale)
+        x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
+        x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
 
-        x[:, :, :, torch.arange(0, x.size(3), upscale)] = 0
-        x[:, :, torch.arange(0, x.size(2), upscale), :] = 0
+        x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
+        x[:, :, torch.arange(0, x.size(2), scale), :] = 0
         x = x[:, :, 1:, 1:]
 
         for n in range(m.size(0)):
             for i in range(m.size(1)):
                 for j in range(m.size(2)):
                     if m[n, i, j] == 0:
-                        for k in range(2, upscale - 2):
-                            x[n, :, i * upscale + k, j * upscale + k] = 0
-                            x[n, :, i * upscale + upscale - 1 - k, j * upscale + k] = 0
+                        for k in range(2, scale - 2):
+                            for l in [0, 1]:
+                                x[n, :, i * scale + k, j * scale + k - l] = 0
+                                x[
+                                    n, :, i * scale + scale - 1 - k, j * scale + k - l
+                                ] = 0
 
         return x
 
-    def seq2img(self, seq, upscale=15):
+    def seq2img(self, seq, scale=15):
         f_first = seq[:, : self.height * self.width].reshape(
             -1, self.height, self.width
         )
@@ -301,47 +295,53 @@ class Sky:
         direction = seq[:, self.height * self.width]
 
         direction_symbol = torch.full(
-            (direction.size(0), self.height * upscale - 1, upscale), 0
+            (direction.size(0), self.height * scale - 1, scale), 0
         )
         direction_symbol = self.colors[direction_symbol].permute(0, 3, 1, 2)
-        separator = torch.full((direction.size(0), 3, self.height * upscale - 1, 1), 0)
+        separator = torch.full((direction.size(0), 3, self.height * scale - 1, 1), 0)
 
         for n in range(direction_symbol.size(0)):
             if direction[n] == self.token_forward:
-                for k in range(upscale):
-                    direction_symbol[
-                        n,
-                        :,
-                        (self.height * upscale) // 2 - upscale // 2 + k,
-                        3 + upscale // 2 - abs(k - upscale // 2),
-                    ] = 0
+                for k in range(scale):
+                    for l in [0, 1]:
+                        direction_symbol[
+                            n,
+                            :,
+                            (self.height * scale) // 2 - scale // 2 + k - l,
+                            3 + scale // 2 - abs(k - scale // 2),
+                        ] = 0
             elif direction[n] == self.token_backward:
-                for k in range(upscale):
-                    direction_symbol[
-                        n,
-                        :,
-                        (self.height * upscale) // 2 - upscale // 2 + k,
-                        3 + abs(k - upscale // 2),
-                    ] = 0
+                for k in range(scale):
+                    for l in [0, 1]:
+                        direction_symbol[
+                            n,
+                            :,
+                            (self.height * scale) // 2 - scale // 2 + k - l,
+                            3 + abs(k - scale // 2),
+                        ] = 0
             else:
-                for k in range(2, upscale - 2):
-                    direction_symbol[
-                        n, :, (self.height * upscale) // 2 - upscale // 2 + k, k
-                    ] = 0
-                    direction_symbol[
-                        n,
-                        :,
-                        (self.height * upscale) // 2 - upscale // 2 + k,
-                        upscale - 1 - k,
-                    ] = 0
+                for k in range(2, scale - 2):
+                    for l in [0, 1]:
+                        direction_symbol[
+                            n,
+                            :,
+                            (self.height * scale) // 2 - scale // 2 + k - l,
+                            k,
+                        ] = 0
+                        direction_symbol[
+                            n,
+                            :,
+                            (self.height * scale) // 2 - scale // 2 + k - l,
+                            scale - 1 - k,
+                        ] = 0
 
         return torch.cat(
             [
-                self.frame2img(f_first, upscale),
+                self.frame2img(f_first, scale),
                 separator,
                 direction_symbol,
                 separator,
-                self.frame2img(f_second, upscale),
+                self.frame2img(f_second, scale),
             ],
             dim=3,
         )
@@ -352,14 +352,13 @@ class Sky:
             result.append("".join([self.token2char[v] for v in s]))
         return result
 
-    def save_image(self, input, result_dir, filename, logger):
+    def save_image(self, input, result_dir, filename):
         img = self.seq2img(input.to("cpu"))
         image_name = os.path.join(result_dir, filename)
         torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
-        logger(f"wrote {image_name}")
 
-    def save_quizzes(self, input, result_dir, filename_prefix, logger):
-        self.save_image(input, result_dir, filename_prefix + ".png", logger)
+    def save_quizzes(self, input, result_dir, filename_prefix):
+        self.save_image(input, result_dir, filename_prefix + ".png")
 
 
 ######################################################################
@@ -372,7 +371,7 @@ if __name__ == "__main__":
     start_time = time.perf_counter()
     seq, it = sky.generate_seq(nb=64, return_iterations=True)
     delay = time.perf_counter() - start_time
-    print(f"{seq.size(0)/delay:02f} samples/s")
+    print(f"{seq.size(0)/delay:02f} seq/s")
 
     print(sky.seq2str(seq[:4]))