Update.
[culture.git] / sky.py
diff --git a/sky.py b/sky.py
index 1164185..4ca4ba7 100755 (executable)
--- a/sky.py
+++ b/sky.py
@@ -157,6 +157,12 @@ class Sky(problem.Problem):
 
     ######################################################################
 
 
     ######################################################################
 
+    def generate_prompts_and_answers(self, nb):
+        frame_sequences = self.generate_frame_sequences(nb)
+        prompts = frame_sequences[:, : frame_sequences.size(0) // 2].flatten(1)
+        answers = frame_sequences[:, frame_sequences.size(0) // 2 :].flatten(1)
+        return prompts, answers
+
     def generate_token_sequences(self, nb):
         frame_sequences = self.generate_frame_sequences(nb)
 
     def generate_token_sequences(self, nb):
         frame_sequences = self.generate_frame_sequences(nb)