+ ######################################################################
+
+ def nb_token_values(self):
+ return len(self.colors)
+
+ def generate_prompts_and_answers(self, nb):
+ frame_sequences = self.generate_frame_sequences(nb)
+ frame_sequences = torch.cat([x[None] for x in frame_sequences], dim=0)
+
+ prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1)
+
+ answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1)
+
+ # warnings.warn("dirty test with longer answer", RuntimeWarning)
+ # answers = torch.cat(
+ # [
+ # frame_sequences[:, frame_sequences.size(1) // 2 :],
+ # frame_sequences[:, frame_sequences.size(1) // 2 :],
+ # ],
+ # dim=3,
+ # ).flatten(1)
+
+ return prompts, answers
+