speed=2,
nb_iterations=2,
avoid_collision=True,
+ max_nb_cached_chunks=None,
+ chunk_size=None,
+ nb_threads=-1,
):
+ super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
self.height = height
self.width = width
self.nb_birds = nb_birds
return prompts, answers
- def save_quizzes(
+ def save_quiz_illustrations(
self,
result_dir,
filename_prefix,
predicted_prompts = torch.randint(3, (prompts.size(0),)) - 1
predicted_answers = torch.randint(3, (prompts.size(0),)) - 1
- sky.save_quizzes(
+ sky.save_quiz_illustrations(
"/tmp", "test", prompts, answers, predicted_prompts, predicted_answers
)