Update.
[culture.git] / sky.py
diff --git a/sky.py b/sky.py
index e509fb7..1768a81 100755 (executable)
--- a/sky.py
+++ b/sky.py
@@ -50,7 +50,11 @@ class Sky(problem.Problem):
         speed=2,
         nb_iterations=2,
         avoid_collision=True,
+        max_nb_cached_chunks=None,
+        chunk_size=None,
+        nb_threads=-1,
     ):
+        super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
         self.height = height
         self.width = width
         self.nb_birds = nb_birds
@@ -217,9 +221,11 @@ class Sky(problem.Problem):
                 y[...] = c
             else:
                 c = c.long()[:, None]
-                c = c * torch.tensor([0, 0, 0], device=c.device) + (
-                    1 - c
-                ) * torch.tensor([255, 255, 255], device=c.device)
+                c = (
+                    (c == 1).long() * torch.tensor([0, 255, 0], device=c.device)
+                    + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device)
+                    + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device)
+                )
                 y[...] = c[:, :, None, None]
 
             y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
@@ -229,6 +235,7 @@ class Sky(problem.Problem):
         margin = 4
 
         img_prompts = add_frame(self.frame2img(prompts.to("cpu")), c=0, margin=1)
+        h = img_prompts.size(2)
         img_answers = add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1)
 
         img_prompts = add_frame(img_prompts, c=255, margin=margin, bottom=True)
@@ -241,7 +248,7 @@ class Sky(problem.Problem):
             img_answers, c=predicted_answers, margin=margin, bottom=True
         )
 
-        marker_size = 8
+        marker_size = 16
 
         separator = img_prompts.new_full(
             (
@@ -253,17 +260,20 @@ class Sky(problem.Problem):
             255,
         )
 
-        for k in range(2, 2 * marker_size - 3):
-            i = k + 1 - marker_size
-            j = marker_size - 2 - abs(k - marker_size + 1)
-            separator[:, :, separator.size(2) // 2 + i, j] = 0
-            separator[:, :, separator.size(2) // 2 + i + 1, j] = 0
+        separator[:, :, 0] = 0
+        separator[:, :, h - 1] = 0
+
+        for k in range(1, 2 * marker_size - 8):
+            i = k - (marker_size - 4)
+            j = marker_size - 5 - abs(i)
+            separator[:, :, h // 2 - 1 + i, 2 + j] = 0
+            separator[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
 
         img = torch.cat([img_prompts, separator, img_answers], dim=3)
 
         image_name = os.path.join(result_dir, filename)
         torchvision.utils.save_image(
-            img.float() / 255.0, image_name, nrow=6, padding=margin * 2, pad_value=1.0
+            img.float() / 255.0, image_name, nrow=6, padding=margin * 4, pad_value=1.0
         )
 
     ######################################################################
@@ -278,6 +288,7 @@ class Sky(problem.Problem):
         prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1)
 
         answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1)
+
         # warnings.warn("dirty test with longer answer", RuntimeWarning)
         # answers = torch.cat(
         # [
@@ -317,8 +328,8 @@ if __name__ == "__main__":
 
     prompts, answers = sky.generate_prompts_and_answers(4)
 
-    predicted_prompts = torch.rand(prompts.size(0)) < 0.5
-    predicted_answers = torch.rand(answers.size(0)) < 0.5
+    predicted_prompts = torch.randint(3, (prompts.size(0),)) - 1
+    predicted_answers = torch.randint(3, (prompts.size(0),)) - 1
 
     sky.save_quizzes(
         "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers