Update.
[culture.git] / sky.py
diff --git a/sky.py b/sky.py
index a90e37d..ac6cbdc 100755 (executable)
--- a/sky.py
+++ b/sky.py
@@ -5,7 +5,7 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
-import math, sys, tqdm
+import math, sys, tqdm, os
 
 import torch, torchvision
 
 
 import torch, torchvision
 
@@ -14,8 +14,10 @@ from torch.nn import functional as F
 
 ######################################################################
 
 
 ######################################################################
 
+import problem
 
 
-class Sky:
+
+class Sky(problem.Problem):
     colors = torch.tensor(
         [
             [255, 255, 255],
     colors = torch.tensor(
         [
             [255, 255, 255],
@@ -48,6 +50,9 @@ class Sky:
         self.nb_birds = nb_birds
         self.nb_iterations = nb_iterations
 
         self.nb_birds = nb_birds
         self.nb_iterations = nb_iterations
 
+    def direction_tokens(self):
+        return self.token_forward, self.token_backward
+
     def generate_seq(self, nb, return_iterations=False):
         pairs = []
         kept_iterations = []
     def generate_seq(self, nb, return_iterations=False):
         pairs = []
         kept_iterations = []
@@ -253,31 +258,34 @@ class Sky:
 
         return torch.cat(result, dim=0)
 
 
         return torch.cat(result, dim=0)
 
-    def frame2img(self, x, upscale=15):
+    def frame2img(self, x, scale=15):
         x = x.reshape(-1, self.height, self.width)
         m = torch.logical_and(
             x >= 0, x < self.first_bird_token + self.nb_bird_tokens
         ).long()
         x = self.colors[x * m].permute(0, 3, 1, 2)
         s = x.shape
         x = x.reshape(-1, self.height, self.width)
         m = torch.logical_and(
             x >= 0, x < self.first_bird_token + self.nb_bird_tokens
         ).long()
         x = self.colors[x * m].permute(0, 3, 1, 2)
         s = x.shape
-        x = x[:, :, :, None, :, None].expand(-1, -1, -1, upscale, -1, upscale)
-        x = x.reshape(s[0], s[1], s[2] * upscale, s[3] * upscale)
+        x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
+        x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
 
 
-        x[:, :, :, torch.arange(0, x.size(3), upscale)] = 0
-        x[:, :, torch.arange(0, x.size(2), upscale), :] = 0
+        x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
+        x[:, :, torch.arange(0, x.size(2), scale), :] = 0
         x = x[:, :, 1:, 1:]
 
         for n in range(m.size(0)):
             for i in range(m.size(1)):
                 for j in range(m.size(2)):
                     if m[n, i, j] == 0:
         x = x[:, :, 1:, 1:]
 
         for n in range(m.size(0)):
             for i in range(m.size(1)):
                 for j in range(m.size(2)):
                     if m[n, i, j] == 0:
-                        for k in range(2, upscale - 2):
-                            x[n, :, i * upscale + k, j * upscale + k] = 0
-                            x[n, :, i * upscale + upscale - 1 - k, j * upscale + k] = 0
+                        for k in range(2, scale - 2):
+                            for l in [0, 1]:
+                                x[n, :, i * scale + k, j * scale + k - l] = 0
+                                x[
+                                    n, :, i * scale + scale - 1 - k, j * scale + k - l
+                                ] = 0
 
         return x
 
 
         return x
 
-    def seq2img(self, seq, upscale=15):
+    def seq2img(self, seq, scale=15):
         f_first = seq[:, : self.height * self.width].reshape(
             -1, self.height, self.width
         )
         f_first = seq[:, : self.height * self.width].reshape(
             -1, self.height, self.width
         )
@@ -287,47 +295,53 @@ class Sky:
         direction = seq[:, self.height * self.width]
 
         direction_symbol = torch.full(
         direction = seq[:, self.height * self.width]
 
         direction_symbol = torch.full(
-            (direction.size(0), self.height * upscale - 1, upscale), 0
+            (direction.size(0), self.height * scale - 1, scale), 0
         )
         direction_symbol = self.colors[direction_symbol].permute(0, 3, 1, 2)
         )
         direction_symbol = self.colors[direction_symbol].permute(0, 3, 1, 2)
-        separator = torch.full((direction.size(0), 3, self.height * upscale - 1, 1), 0)
+        separator = torch.full((direction.size(0), 3, self.height * scale - 1, 1), 0)
 
         for n in range(direction_symbol.size(0)):
             if direction[n] == self.token_forward:
 
         for n in range(direction_symbol.size(0)):
             if direction[n] == self.token_forward:
-                for k in range(upscale):
-                    direction_symbol[
-                        n,
-                        :,
-                        (self.height * upscale) // 2 - upscale // 2 + k,
-                        3 + upscale // 2 - abs(k - upscale // 2),
-                    ] = 0
+                for k in range(scale):
+                    for l in [0, 1]:
+                        direction_symbol[
+                            n,
+                            :,
+                            (self.height * scale) // 2 - scale // 2 + k - l,
+                            3 + scale // 2 - abs(k - scale // 2),
+                        ] = 0
             elif direction[n] == self.token_backward:
             elif direction[n] == self.token_backward:
-                for k in range(upscale):
-                    direction_symbol[
-                        n,
-                        :,
-                        (self.height * upscale) // 2 - upscale // 2 + k,
-                        3 + abs(k - upscale // 2),
-                    ] = 0
+                for k in range(scale):
+                    for l in [0, 1]:
+                        direction_symbol[
+                            n,
+                            :,
+                            (self.height * scale) // 2 - scale // 2 + k - l,
+                            3 + abs(k - scale // 2),
+                        ] = 0
             else:
             else:
-                for k in range(2, upscale - 2):
-                    direction_symbol[
-                        n, :, (self.height * upscale) // 2 - upscale // 2 + k, k
-                    ] = 0
-                    direction_symbol[
-                        n,
-                        :,
-                        (self.height * upscale) // 2 - upscale // 2 + k,
-                        upscale - 1 - k,
-                    ] = 0
+                for k in range(2, scale - 2):
+                    for l in [0, 1]:
+                        direction_symbol[
+                            n,
+                            :,
+                            (self.height * scale) // 2 - scale // 2 + k - l,
+                            k,
+                        ] = 0
+                        direction_symbol[
+                            n,
+                            :,
+                            (self.height * scale) // 2 - scale // 2 + k - l,
+                            scale - 1 - k,
+                        ] = 0
 
         return torch.cat(
             [
 
         return torch.cat(
             [
-                self.frame2img(f_first, upscale),
+                self.frame2img(f_first, scale),
                 separator,
                 direction_symbol,
                 separator,
                 separator,
                 direction_symbol,
                 separator,
-                self.frame2img(f_second, upscale),
+                self.frame2img(f_second, scale),
             ],
             dim=3,
         )
             ],
             dim=3,
         )
@@ -338,6 +352,14 @@ class Sky:
             result.append("".join([self.token2char[v] for v in s]))
         return result
 
             result.append("".join([self.token2char[v] for v in s]))
         return result
 
+    def save_image(self, input, result_dir, filename):
+        img = self.seq2img(input.to("cpu"))
+        image_name = os.path.join(result_dir, filename)
+        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
+
+    def save_quizzes(self, input, result_dir, filename_prefix):
+        self.save_image(input, result_dir, filename_prefix + ".png")
+
 
 ######################################################################
 
 
 ######################################################################
 
@@ -349,7 +371,7 @@ if __name__ == "__main__":
     start_time = time.perf_counter()
     seq, it = sky.generate_seq(nb=64, return_iterations=True)
     delay = time.perf_counter() - start_time
     start_time = time.perf_counter()
     seq, it = sky.generate_seq(nb=64, return_iterations=True)
     delay = time.perf_counter() - start_time
-    print(f"{seq.size(0)/delay:02f} samples/s")
+    print(f"{seq.size(0)/delay:02f} seq/s")
 
     print(sky.seq2str(seq[:4]))
 
 
     print(sky.seq2str(seq[:4]))