X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=sky.py;h=4ca4ba7136b40a5324dcb64ba4c6a3a19523b3c5;hb=c9c018e4c19ce92892d7652082fb90719d57441c;hp=11641853d8e8081f1cc7d0cf63d112f1ba30b518;hpb=aae01e186a959131b446d0365c6b951bacfd71d9;p=culture.git diff --git a/sky.py b/sky.py index 1164185..4ca4ba7 100755 --- a/sky.py +++ b/sky.py @@ -157,6 +157,12 @@ class Sky(problem.Problem): ###################################################################### + def generate_prompts_and_answers(self, nb): + frame_sequences = self.generate_frame_sequences(nb) + prompts = frame_sequences[:, : frame_sequences.size(0) // 2].flatten(1) + answers = frame_sequences[:, frame_sequences.size(0) // 2 :].flatten(1) + return prompts, answers + def generate_token_sequences(self, nb): frame_sequences = self.generate_frame_sequences(nb)