Update.
[culture.git] / sky.py
diff --git a/sky.py b/sky.py
index d2a4568..040ec67 100755 (executable)
--- a/sky.py
+++ b/sky.py
@@ -42,9 +42,6 @@ class Sky(problem.Problem):
         "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
     )
 
         "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
     )
 
-    def nb_token_values(self):
-        return len(self.colors)
-
     def __init__(
         self,
         height=6,
     def __init__(
         self,
         height=6,
@@ -155,17 +152,8 @@ class Sky(problem.Problem):
 
     ######################################################################
 
 
     ######################################################################
 
-    def generate_prompts_and_answers(self, nb):
-        frame_sequences = self.generate_frame_sequences(nb)
-        frame_sequences = torch.cat([x[None] for x in frame_sequences], dim=0)
-        prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1)
-        answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1)
-        return prompts, answers
-
-    ######################################################################
-
     def frame2img(self, x, scale=15):
     def frame2img(self, x, scale=15):
-        x = x.reshape(-1, self.height, self.width)
+        x = x.reshape(x.size(0), self.height, -1)
         m = torch.logical_and(
             x >= 0, x < self.first_bird_token + self.nb_bird_tokens
         ).long()
         m = torch.logical_and(
             x >= 0, x < self.first_bird_token + self.nb_bird_tokens
         ).long()
@@ -250,6 +238,18 @@ class Sky(problem.Problem):
             img.float() / 255.0, image_name, nrow=6, padding=margin * 2, pad_value=1.0
         )
 
             img.float() / 255.0, image_name, nrow=6, padding=margin * 2, pad_value=1.0
         )
 
+    ######################################################################
+
+    def nb_token_values(self):
+        return len(self.colors)
+
+    def generate_prompts_and_answers(self, nb):
+        frame_sequences = self.generate_frame_sequences(nb)
+        frame_sequences = torch.cat([x[None] for x in frame_sequences], dim=0)
+        prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1)
+        answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1)
+        return prompts, answers
+
     def save_quizzes(
         self,
         result_dir,
     def save_quizzes(
         self,
         result_dir,
@@ -274,7 +274,7 @@ class Sky(problem.Problem):
 if __name__ == "__main__":
     import time
 
 if __name__ == "__main__":
     import time
 
-    sky = Sky(height=6, width=8, speed=4, nb_iterations=2)
+    sky = Sky(height=6, width=8, speed=1, nb_iterations=4)
 
     prompts, answers = sky.generate_prompts_and_answers(4)
 
 
     prompts, answers = sky.generate_prompts_and_answers(4)