X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=grids.py;h=002a33ffe4d56bfe0459b19cf6897a97e55babd4;hb=2f87c91cf606a068de1450d198660de7e44cd356;hp=7aec62c16dd26f88ea9ea9d4938e099646f4f1ba;hpb=12c775dcbd3d3cd703f35c181faa6d2a680a0450;p=culture.git diff --git a/grids.py b/grids.py index 7aec62c..002a33f 100755 --- a/grids.py +++ b/grids.py @@ -37,11 +37,34 @@ class Grids(problem.Problem): max_nb_cached_chunks=None, chunk_size=None, nb_threads=-1, + tasks=None, ): self.colors = torch.tensor([c for _, c in self.named_colors]) self.height = 10 self.width = 10 self.cache_rec_coo = {} + + all_tasks = [ + self.task_replace_color, + self.task_translate, + self.task_grow, + self.task_half_fill, + self.task_frame, + self.task_detect, + self.task_count, + self.task_trajectory, + self.task_bounce, + self.task_scale, + self.task_symbols, + self.task_isometry, + # self.task_path, + ] + + if tasks is None: + self.all_tasks = all_tasks + else: + self.all_tasks = [getattr(self, "task_" + t) for t in tasks.split(",")] + super().__init__(max_nb_cached_chunks, chunk_size, nb_threads) ###################################################################### @@ -398,7 +421,7 @@ class Grids(problem.Problem): f_X[i1:i2, j1:j2] = c[n] # @torch.compile - def task_color_grow(self, A, f_A, B, f_B): + def task_half_fill(self, A, f_A, B, f_B): di, dj = torch.randint(2, (2,)) * 2 - 1 nb_rec = 3 c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1 @@ -715,7 +738,7 @@ class Grids(problem.Problem): f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q] # @torch.compile - def task_ortho(self, A, f_A, B, f_B): + def task_isometry(self, A, f_A, B, f_B): nb_rec = 3 di, dj = torch.randint(3, (2,)) - 1 o = torch.tensor([[0.0, 1.0], [-1.0, 0.0]]) @@ -939,23 +962,6 @@ class Grids(problem.Problem): ###################################################################### - def all_tasks(self): - return [ - self.task_replace_color, - self.task_translate, - self.task_grow, - self.task_color_grow, - self.task_frame, - self.task_detect, - self.task_count, - self.task_trajectory, - self.task_bounce, - self.task_scale, - self.task_symbols, - self.task_ortho, - # self.task_path, - ] - def trivial_prompts_and_answers(self, prompts, answers): S = self.height * self.width Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S] @@ -964,7 +970,7 @@ class Grids(problem.Problem): def generate_prompts_and_answers_(self, nb, tasks=None, progress_bar=False): if tasks is None: - tasks = self.all_tasks() + tasks = self.all_tasks S = self.height * self.width prompts = torch.zeros(nb, 3 * S + 2, dtype=torch.int64) @@ -990,7 +996,7 @@ class Grids(problem.Problem): return prompts.flatten(1), answers.flatten(1) - def save_quizzes( + def save_quiz_illustrations( self, result_dir, filename_prefix, @@ -1012,10 +1018,10 @@ class Grids(problem.Problem): def save_some_examples(self, result_dir): nb, nrow = 72, 4 - for t in self.all_tasks(): + for t in self.all_tasks: print(t.__name__) prompts, answers = self.generate_prompts_and_answers_(nb, tasks=[t]) - self.save_quizzes( + self.save_quiz_illustrations( result_dir, t.__name__, prompts[:nb], answers[:nb], nrow=nrow ) @@ -1043,18 +1049,23 @@ if __name__ == "__main__": nb, nrow = 72, 4 # nb, nrow = 8, 2 - # for t in grids.all_tasks(): - for t in [grids.task_puzzle]: + # for t in grids.all_tasks: + for t in [ + grids.task_replace_color, + grids.task_frame, + ]: print(t.__name__) prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t]) - grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow) + grids.save_quiz_illustrations( + "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow + ) exit(0) nb = 1000 - for t in grids.all_tasks(): - # for t in [ grids.task_replace_color ]: #grids.all_tasks(): + for t in grids.all_tasks: + # for t in [ grids.task_replace_color ]: #grids.all_tasks: start_time = time.perf_counter() prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t]) delay = time.perf_counter() - start_time @@ -1066,7 +1077,7 @@ if __name__ == "__main__": predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1) predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1) - grids.save_quizzes( + grids.save_quiz_illustrations( "/tmp", "test", prompts[:nb],