Update.
[culture.git] / sky.py
diff --git a/sky.py b/sky.py
index 2183cf1..040ec67 100755 (executable)
--- a/sky.py
+++ b/sky.py
@@ -42,9 +42,6 @@ class Sky(problem.Problem):
         "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
     )
 
-    def nb_token_values(self):
-        return len(self.colors)
-
     def __init__(
         self,
         height=6,
@@ -155,15 +152,6 @@ class Sky(problem.Problem):
 
     ######################################################################
 
-    def generate_prompts_and_answers(self, nb):
-        frame_sequences = self.generate_frame_sequences(nb)
-        frame_sequences = torch.cat([x[None] for x in frame_sequences], dim=0)
-        prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1)
-        answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1)
-        return prompts, answers
-
-    ######################################################################
-
     def frame2img(self, x, scale=15):
         x = x.reshape(x.size(0), self.height, -1)
         m = torch.logical_and(
@@ -250,6 +238,18 @@ class Sky(problem.Problem):
             img.float() / 255.0, image_name, nrow=6, padding=margin * 2, pad_value=1.0
         )
 
+    ######################################################################
+
+    def nb_token_values(self):
+        return len(self.colors)
+
+    def generate_prompts_and_answers(self, nb):
+        frame_sequences = self.generate_frame_sequences(nb)
+        frame_sequences = torch.cat([x[None] for x in frame_sequences], dim=0)
+        prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1)
+        answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1)
+        return prompts, answers
+
     def save_quizzes(
         self,
         result_dir,