Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 25 Jun 2024 18:38:31 +0000 (20:38 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 25 Jun 2024 18:38:31 +0000 (20:38 +0200)
sky.py

diff --git a/sky.py b/sky.py
index 1e6ed4d..3584beb 100755 (executable)
--- a/sky.py
+++ b/sky.py
@@ -258,31 +258,34 @@ class Sky(problem.Problem):
 
         return torch.cat(result, dim=0)
 
-    def frame2img(self, x, upscale=15):
+    def frame2img(self, x, scale=15):
         x = x.reshape(-1, self.height, self.width)
         m = torch.logical_and(
             x >= 0, x < self.first_bird_token + self.nb_bird_tokens
         ).long()
         x = self.colors[x * m].permute(0, 3, 1, 2)
         s = x.shape
-        x = x[:, :, :, None, :, None].expand(-1, -1, -1, upscale, -1, upscale)
-        x = x.reshape(s[0], s[1], s[2] * upscale, s[3] * upscale)
+        x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
+        x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
 
-        x[:, :, :, torch.arange(0, x.size(3), upscale)] = 0
-        x[:, :, torch.arange(0, x.size(2), upscale), :] = 0
+        x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
+        x[:, :, torch.arange(0, x.size(2), scale), :] = 0
         x = x[:, :, 1:, 1:]
 
         for n in range(m.size(0)):
             for i in range(m.size(1)):
                 for j in range(m.size(2)):
                     if m[n, i, j] == 0:
-                        for k in range(2, upscale - 2):
-                            x[n, :, i * upscale + k, j * upscale + k] = 0
-                            x[n, :, i * upscale + upscale - 1 - k, j * upscale + k] = 0
+                        for k in range(2, scale - 2):
+                            for l in [0, 1]:
+                                x[n, :, i * scale + k, j * scale + k - l] = 0
+                                x[
+                                    n, :, i * scale + scale - 1 - k, j * scale + k - l
+                                ] = 0
 
         return x
 
-    def seq2img(self, seq, upscale=15):
+    def seq2img(self, seq, scale=15):
         f_first = seq[:, : self.height * self.width].reshape(
             -1, self.height, self.width
         )
@@ -292,47 +295,53 @@ class Sky(problem.Problem):
         direction = seq[:, self.height * self.width]
 
         direction_symbol = torch.full(
-            (direction.size(0), self.height * upscale - 1, upscale), 0
+            (direction.size(0), self.height * scale - 1, scale), 0
         )
         direction_symbol = self.colors[direction_symbol].permute(0, 3, 1, 2)
-        separator = torch.full((direction.size(0), 3, self.height * upscale - 1, 1), 0)
+        separator = torch.full((direction.size(0), 3, self.height * scale - 1, 1), 0)
 
         for n in range(direction_symbol.size(0)):
             if direction[n] == self.token_forward:
-                for k in range(upscale):
-                    direction_symbol[
-                        n,
-                        :,
-                        (self.height * upscale) // 2 - upscale // 2 + k,
-                        3 + upscale // 2 - abs(k - upscale // 2),
-                    ] = 0
+                for k in range(scale):
+                    for l in [0, 1]:
+                        direction_symbol[
+                            n,
+                            :,
+                            (self.height * scale) // 2 - scale // 2 + k - l,
+                            3 + scale // 2 - abs(k - scale // 2),
+                        ] = 0
             elif direction[n] == self.token_backward:
-                for k in range(upscale):
-                    direction_symbol[
-                        n,
-                        :,
-                        (self.height * upscale) // 2 - upscale // 2 + k,
-                        3 + abs(k - upscale // 2),
-                    ] = 0
+                for k in range(scale):
+                    for l in [0, 1]:
+                        direction_symbol[
+                            n,
+                            :,
+                            (self.height * scale) // 2 - scale // 2 + k - l,
+                            3 + abs(k - scale // 2),
+                        ] = 0
             else:
-                for k in range(2, upscale - 2):
-                    direction_symbol[
-                        n, :, (self.height * upscale) // 2 - upscale // 2 + k, k
-                    ] = 0
-                    direction_symbol[
-                        n,
-                        :,
-                        (self.height * upscale) // 2 - upscale // 2 + k,
-                        upscale - 1 - k,
-                    ] = 0
+                for k in range(2, scale - 2):
+                    for l in [0, 1]:
+                        direction_symbol[
+                            n,
+                            :,
+                            (self.height * scale) // 2 - scale // 2 + k - l,
+                            k,
+                        ] = 0
+                        direction_symbol[
+                            n,
+                            :,
+                            (self.height * scale) // 2 - scale // 2 + k - l,
+                            scale - 1 - k,
+                        ] = 0
 
         return torch.cat(
             [
-                self.frame2img(f_first, upscale),
+                self.frame2img(f_first, scale),
                 separator,
                 direction_symbol,
                 separator,
-                self.frame2img(f_second, upscale),
+                self.frame2img(f_second, scale),
             ],
             dim=3,
         )