From dc09c78665cd4a2f1fb2899719e1c50d9fa21696 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 25 Jul 2024 18:15:40 +0200 Subject: [PATCH] Update. --- grids.py | 89 ++++++++++++++++++++++++++++++++++++------------- main.py | 57 ++++++++++++++++++++----------- quiz_machine.py | 2 -- 3 files changed, 103 insertions(+), 45 deletions(-) diff --git a/grids.py b/grids.py index 1d94e07..67a5c97 100755 --- a/grids.py +++ b/grids.py @@ -233,8 +233,7 @@ class Grids(problem.Problem): max_nb_cached_chunks=None, chunk_size=None, nb_threads=-1, - world_tasks=None, - science_tasks=None, + tasks=None, ): self.colors = torch.tensor([c for _, c in self.named_colors]) @@ -264,7 +263,7 @@ class Grids(problem.Problem): self.cache_rec_coo = {} - self.all_tasks = [ + all_tasks = [ self.task_replace_color, self.task_translate, self.task_grow, @@ -285,17 +284,10 @@ class Grids(problem.Problem): # self.task_islands, # TOO MESSY ] - if world_tasks is None: - self.world_tasks = self.all_tasks + if tasks is None: + self.all_tasks = all_tasks else: - self.world_tasks = [ - getattr(self, "task_" + t) for t in world_tasks.split(",") - ] - - if science_tasks is not None: - self.science_tasks = [ - getattr(self, "task_" + t) for t in science_tasks.split(",") - ] + self.all_tasks = [getattr(self, "task_" + t) for t in tasks.split(",")] super().__init__(max_nb_cached_chunks, chunk_size, nb_threads) @@ -1445,6 +1437,59 @@ class Grids(problem.Problem): rs = (S // (2**j)) % 2 f_X[2, -j - 1] = c[2 + rs] + def task_science_implicit(self, A, f_A, B, f_B): + nb_rec = 5 + c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1 + + for X, f_X in [(A, f_A), (B, f_B)]: + while True: + i1, i2 = torch.randint(self.height, (2,)).sort().values + if i1 >= 1 and i2 < self.height and i1 + 3 < i2: + break + + while True: + j1, j2 = torch.randint(self.width, (2,)).sort().values + if j1 >= 1 and j2 < self.width and j1 + 3 < j2: + break + + f_X[i1:i2, j1:j2] = c[0] + + # --------------------- + + while True: + ii1, ii2 = torch.randint(self.height, (2,)).sort().values + if ii1 >= i1 and ii2 <= i2 and ii1 + 1 < ii2: + break + jj = torch.randint(j1, (1,)) + X[ii1:ii2, jj:j1] = c[1] + f_X[ii1:ii2, jj:j1] = c[1] + + while True: + ii1, ii2 = torch.randint(self.height, (2,)).sort().values + if ii1 >= i1 and ii2 <= i2 and ii1 + 1 < ii2: + break + jj = torch.randint(self.width - j2, (1,)) + j2 + 1 + X[ii1:ii2, j2:jj] = c[2] + f_X[ii1:ii2, j2:jj] = c[2] + + # --------------------- + + while True: + jj1, jj2 = torch.randint(self.width, (2,)).sort().values + if jj1 >= j1 and jj2 <= j2 and jj1 + 1 < jj2: + break + ii = torch.randint(i1, (1,)) + X[ii:i1, jj1:jj2] = c[3] + f_X[ii:i1, jj1:jj2] = c[3] + + while True: + jj1, jj2 = torch.randint(self.width, (2,)).sort().values + if jj1 >= j1 and jj2 <= j2 and jj1 + 1 < jj2: + break + ii = torch.randint(self.height - i2, (1,)) + i2 + 1 + X[i2:ii, jj1:jj2] = c[4] + f_X[i2:ii, jj1:jj2] = c[4] + # end_tasks ###################################################################### @@ -1459,14 +1504,11 @@ class Grids(problem.Problem): return quizzes - def generate_w_quizzes_(self, nb, tasks=None, science=False, progress_bar=False): + def generate_w_quizzes_(self, nb, tasks=None, progress_bar=False): S = self.height * self.width if tasks is None: - if science: - tasks = self.science_tasks - else: - tasks = self.world_tasks + tasks = self.all_tasks quizzes = self.create_empty_quizzes(nb, ("A", "f_A", "B", "f_B")) @@ -1487,10 +1529,9 @@ class Grids(problem.Problem): return quizzes - def save_some_examples(self, result_dir, science=False): + def save_some_examples(self, result_dir): nb, nrow = 128, 4 - tasks = self.science_tasks if science else self.world_tasks - for t in tasks: + for t in self.all_tasks: print(t.__name__) quizzes = self.generate_w_quizzes_(nb, tasks=[t]) self.save_quizzes_as_image( @@ -1545,9 +1586,9 @@ if __name__ == "__main__": nb, nrow = 128, 4 # nb, nrow = 8, 2 - # for t in grids.world_tasks: + # for t in grids.all_tasks: - for t in [grids.task_path]: + for t in [grids.task_science_implicit]: print(t.__name__) quizzes = grids.generate_w_quizzes_(nb, tasks=[t]) grids.save_quizzes_as_image( @@ -1557,7 +1598,7 @@ if __name__ == "__main__": comments=[f"{t.__name__} #{k}" for k in range(quizzes.size(0))], ) - # exit(0) + exit(0) nb = 1000 diff --git a/main.py b/main.py index b49fa06..4d618cc 100755 --- a/main.py +++ b/main.py @@ -125,14 +125,6 @@ parser.add_argument( help="A comma-separated subset of: " + grids_tasks + ", or None for all.", ) -assert ( - len( - set(args.grids_world_tasks.split(",")) - & set(args.grids_science_tasks.split(",")) - ) - == 0 -), "World and science task have to be disjoint" - ###################################################################### parser.add_argument("--sky_height", type=int, default=6) @@ -152,6 +144,14 @@ args = parser.parse_args() if args.result_dir is None: args.result_dir = f"results_culture" +assert not args.grids_science_tasks or ( + len( + set(args.grids_world_tasks.split(",")) + & set(args.grids_science_tasks.split(",")) + ) + == 0 +), "World and science tasks have to be disjoint" + ###################################################################### default_args = { @@ -304,17 +304,25 @@ if args.problem == "sky": chunk_size=100, nb_threads=args.nb_threads, ) - back_accuracy = False elif args.problem == "grids": problem = grids.Grids( max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100, chunk_size=100, nb_threads=args.nb_threads, - world_tasks=args.grids_world_tasks, - science_tasks=args.grids_science_tasks, + tasks=args.grids_world_tasks, ) - back_accuracy = True + + if args.grids_science_tasks is None: + science_w_quizzes = None + else: + science_problem = grids.Grids( + max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100, + chunk_size=100, + nb_threads=args.nb_threads, + tasks=args.grids_science_tasks, + ) + science_w_quizzes = science_problem.generate_w_quizzes(args.nb_test_samples) else: raise ValueError @@ -324,7 +332,6 @@ if not args.resume: quiz_machine = quiz_machine.QuizMachine( problem=problem, - back_accuracy=back_accuracy, batch_size=args.physical_batch_size, result_dir=args.result_dir, logger=log_string, @@ -600,12 +607,6 @@ for k in range(args.nb_gpts): ###################################################################### -science_w_quizzes = quiz_machine.problem.generate_w_quizzes( - args.nb_test_samples, science=True -) - -###################################################################### - current_epoch = 0 if args.resume: @@ -757,6 +758,24 @@ for n_epoch in range(current_epoch, args.nb_epochs): c_quizzes, ) + ###################################################################### + + if science_w_quizzes is not None: + result, correct = quiz_machine.predict( + model=model, + quizzes=science_w_quizzes.to(main_device), + struct=("A", "f_A", "B", "f_B"), + mask=(0, 0, 0, 1), + ) + + nb_correct = (correct == 1).long().sum() + nb_total = (correct != 0).long().sum() + log_string( + f"science_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total}" + ) + + ###################################################################### + # Renew the training samples for model in weakest_models: diff --git a/quiz_machine.py b/quiz_machine.py index 4048b39..8e40921 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -115,7 +115,6 @@ class QuizMachine: def __init__( self, problem, - back_accuracy, batch_size, result_dir, logger, @@ -124,7 +123,6 @@ class QuizMachine: super().__init__() self.problem = problem - self.back_accuracy = back_accuracy self.batch_size = batch_size self.device = device self.logger = logger -- 2.39.5