Merge branch 'dev'
[culture.git] / sky.py
diff --git a/sky.py b/sky.py
index ed440d3..cc5bd4f 100755 (executable)
--- a/sky.py
+++ b/sky.py
@@ -50,7 +50,11 @@ class Sky(problem.Problem):
         speed=2,
         nb_iterations=2,
         avoid_collision=True,
+        max_nb_cached_chunks=None,
+        chunk_size=None,
+        nb_threads=-1,
     ):
+        super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
         self.height = height
         self.width = width
         self.nb_birds = nb_birds
@@ -296,7 +300,7 @@ class Sky(problem.Problem):
 
         return prompts, answers
 
-    def save_quizzes(
+    def save_quiz_illustrations(
         self,
         result_dir,
         filename_prefix,
@@ -327,7 +331,7 @@ if __name__ == "__main__":
     predicted_prompts = torch.randint(3, (prompts.size(0),)) - 1
     predicted_answers = torch.randint(3, (prompts.size(0),)) - 1
 
-    sky.save_quizzes(
+    sky.save_quiz_illustrations(
         "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers
     )