X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=sky.py;h=cc5bd4fd5cfedec896b88d7549d1b831ae2ef265;hb=5f5c6c079c2751a76887444c211c5c464e875ed0;hp=ed440d37a37af02b6b2a19cb6ce20945d3d59afe;hpb=30c76210e3ed2704b2a059208f385cb623c1486d;p=culture.git diff --git a/sky.py b/sky.py index ed440d3..cc5bd4f 100755 --- a/sky.py +++ b/sky.py @@ -50,7 +50,11 @@ class Sky(problem.Problem): speed=2, nb_iterations=2, avoid_collision=True, + max_nb_cached_chunks=None, + chunk_size=None, + nb_threads=-1, ): + super().__init__(max_nb_cached_chunks, chunk_size, nb_threads) self.height = height self.width = width self.nb_birds = nb_birds @@ -296,7 +300,7 @@ class Sky(problem.Problem): return prompts, answers - def save_quizzes( + def save_quiz_illustrations( self, result_dir, filename_prefix, @@ -327,7 +331,7 @@ if __name__ == "__main__": predicted_prompts = torch.randint(3, (prompts.size(0),)) - 1 predicted_answers = torch.randint(3, (prompts.size(0),)) - 1 - sky.save_quizzes( + sky.save_quiz_illustrations( "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers )