X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=sky.py;h=4ca4ba7136b40a5324dcb64ba4c6a3a19523b3c5;hb=c9c018e4c19ce92892d7652082fb90719d57441c;hp=11641853d8e8081f1cc7d0cf63d112f1ba30b518;hpb=a06227e98fcef1960b8706c9a1ce72d10bd068c3;p=culture.git diff --git a/sky.py b/sky.py index 1164185..4ca4ba7 100755 --- a/sky.py +++ b/sky.py @@ -157,6 +157,12 @@ class Sky(problem.Problem): ###################################################################### + def generate_prompts_and_answers(self, nb): + frame_sequences = self.generate_frame_sequences(nb) + prompts = frame_sequences[:, : frame_sequences.size(0) // 2].flatten(1) + answers = frame_sequences[:, frame_sequences.size(0) // 2 :].flatten(1) + return prompts, answers + def generate_token_sequences(self, nb): frame_sequences = self.generate_frame_sequences(nb)