"_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
)
- def nb_token_values(self):
- return len(self.colors)
-
def __init__(
self,
height=6,
######################################################################
- def generate_prompts_and_answers(self, nb):
- frame_sequences = self.generate_frame_sequences(nb)
- frame_sequences = torch.cat([x[None] for x in frame_sequences], dim=0)
- prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1)
- answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1)
- return prompts, answers
-
- ######################################################################
-
def frame2img(self, x, scale=15):
- x = x.reshape(-1, self.height, self.width)
+ x = x.reshape(x.size(0), self.height, -1)
m = torch.logical_and(
x >= 0, x < self.first_bird_token + self.nb_bird_tokens
).long()
img.float() / 255.0, image_name, nrow=6, padding=margin * 2, pad_value=1.0
)
+ ######################################################################
+
+ def nb_token_values(self):
+ return len(self.colors)
+
+ def generate_prompts_and_answers(self, nb):
+ frame_sequences = self.generate_frame_sequences(nb)
+ frame_sequences = torch.cat([x[None] for x in frame_sequences], dim=0)
+ prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1)
+ answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1)
+ return prompts, answers
+
def save_quizzes(
self,
result_dir,
if __name__ == "__main__":
import time
- sky = Sky(height=6, width=8, speed=4, nb_iterations=2)
+ sky = Sky(height=6, width=8, speed=1, nb_iterations=4)
prompts, answers = sky.generate_prompts_and_answers(4)