From d4a268427ccb0da17d5cb918124602b994f9397a Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 25 Jul 2024 16:54:46 +0200 Subject: [PATCH] Update. --- grids.py | 73 +++++++++++++++++++++++++++++++++++++++++-------- main.py | 61 ++++++++++++++++++++++++++++------------- quiz_machine.py | 56 +++++++++++++++---------------------- 3 files changed, 126 insertions(+), 64 deletions(-) diff --git a/grids.py b/grids.py index 93b027a..1d94e07 100755 --- a/grids.py +++ b/grids.py @@ -5,7 +5,7 @@ # Written by Francois Fleuret -import math, sys, tqdm, os, warnings +import math, sys, tqdm, os, warnings, cairo import torch, torchvision @@ -14,6 +14,36 @@ from torch.nn import functional as F ###################################################################### + +def text_img(height, width, text): + pixel_map = torch.full((height, width, 4), 255, dtype=torch.uint8) + + surface = cairo.ImageSurface.create_for_data( + pixel_map.numpy(), cairo.FORMAT_ARGB32, pixel_map.size(1), pixel_map.size(0) + ) + + ctx = cairo.Context(surface) + ctx.set_source_rgb(0, 0, 0) + ctx.set_font_size(16) + ctx.select_font_face("courier", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL) + y = None + for line in text.split("\n"): + xbearing, ybearing, width, height, dx, dy = ctx.text_extents(line) + if y is None: + y = height * 1.5 + x = height * 0.5 + + ctx.move_to(x, y) + ctx.show_text(line) + y += height * 1.5 + + ctx.stroke() + + return pixel_map.permute(2, 0, 1)[None, :3].contiguous() + + +###################################################################### + import problem @@ -203,7 +233,8 @@ class Grids(problem.Problem): max_nb_cached_chunks=None, chunk_size=None, nb_threads=-1, - tasks=None, + world_tasks=None, + science_tasks=None, ): self.colors = torch.tensor([c for _, c in self.named_colors]) @@ -233,7 +264,7 @@ class Grids(problem.Problem): self.cache_rec_coo = {} - all_tasks = [ + self.all_tasks = [ self.task_replace_color, self.task_translate, self.task_grow, @@ -254,10 +285,17 @@ class Grids(problem.Problem): # self.task_islands, # TOO MESSY ] - if tasks is None: - self.all_tasks = all_tasks + if world_tasks is None: + self.world_tasks = self.all_tasks else: - self.all_tasks = [getattr(self, "task_" + t) for t in tasks.split(",")] + self.world_tasks = [ + getattr(self, "task_" + t) for t in world_tasks.split(",") + ] + + if science_tasks is not None: + self.science_tasks = [ + getattr(self, "task_" + t) for t in science_tasks.split(",") + ] super().__init__(max_nb_cached_chunks, chunk_size, nb_threads) @@ -305,6 +343,8 @@ class Grids(problem.Problem): quizzes, predicted_parts=None, correct_parts=None, + comments=None, + comment_height=64, nrow=4, margin=8, ): @@ -365,6 +405,11 @@ class Grids(problem.Problem): img = torch.cat([img_A, img_f_A, img_B, img_f_B], dim=3) + if comments is not None: + comment_img = [text_img(comment_height, img.size(3), t) for t in comments] + comment_img = torch.cat(comment_img, dim=0) + img = torch.cat([img, comment_img], dim=2) + image_name = os.path.join(result_dir, filename) torchvision.utils.save_image( @@ -1414,11 +1459,14 @@ class Grids(problem.Problem): return quizzes - def generate_w_quizzes_(self, nb, tasks=None, progress_bar=False): + def generate_w_quizzes_(self, nb, tasks=None, science=False, progress_bar=False): S = self.height * self.width if tasks is None: - tasks = self.all_tasks + if science: + tasks = self.science_tasks + else: + tasks = self.world_tasks quizzes = self.create_empty_quizzes(nb, ("A", "f_A", "B", "f_B")) @@ -1439,9 +1487,10 @@ class Grids(problem.Problem): return quizzes - def save_some_examples(self, result_dir): + def save_some_examples(self, result_dir, science=False): nb, nrow = 128, 4 - for t in self.all_tasks: + tasks = self.science_tasks if science else self.world_tasks + for t in tasks: print(t.__name__) quizzes = self.generate_w_quizzes_(nb, tasks=[t]) self.save_quizzes_as_image( @@ -1496,7 +1545,8 @@ if __name__ == "__main__": nb, nrow = 128, 4 # nb, nrow = 8, 2 - # for t in grids.all_tasks: + # for t in grids.world_tasks: + for t in [grids.task_path]: print(t.__name__) quizzes = grids.generate_w_quizzes_(nb, tasks=[t]) @@ -1504,6 +1554,7 @@ if __name__ == "__main__": "/tmp", t.__name__ + ".png", quizzes, + comments=[f"{t.__name__} #{k}" for k in range(quizzes.size(0))], ) # exit(0) diff --git a/main.py b/main.py index 257f40f..b49fa06 100755 --- a/main.py +++ b/main.py @@ -112,12 +112,27 @@ grids_tasks = ", ".join( ) parser.add_argument( - "--grids_tasks", + "--grids_world_tasks", type=str, default=None, help="A comma-separated subset of: " + grids_tasks + ", or None for all.", ) +parser.add_argument( + "--grids_science_tasks", + type=str, + default=None, + help="A comma-separated subset of: " + grids_tasks + ", or None for all.", +) + +assert ( + len( + set(args.grids_world_tasks.split(",")) + & set(args.grids_science_tasks.split(",")) + ) + == 0 +), "World and science task have to be disjoint" + ###################################################################### parser.add_argument("--sky_height", type=int, default=6) @@ -290,14 +305,17 @@ if args.problem == "sky": nb_threads=args.nb_threads, ) back_accuracy = False + elif args.problem == "grids": problem = grids.Grids( max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100, chunk_size=100, nb_threads=args.nb_threads, - tasks=args.grids_tasks, + world_tasks=args.grids_world_tasks, + science_tasks=args.grids_science_tasks, ) back_accuracy = True + else: raise ValueError @@ -465,9 +483,9 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 # fail(s) # This is nb_quizzes x nb_models - number_correct_responses = 0 - remains = [c_quizzes.size(0)] + number_correct_responses = 0 + nb_remaining = [c_quizzes.size(0)] for r in range(args.nb_rounds): if c_quizzes.size(0) == 0: @@ -487,7 +505,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 c_quizzes = c_quizzes[to_keep] number_correct_responses = number_correct_responses[to_keep] - remains.append(c_quizzes.size(0)) + nb_remaining.append(c_quizzes.size(0)) if c_quizzes.size(0) > 0: nb_validated_per_model[model_for_generation.id] += c_quizzes.size(0) @@ -512,7 +530,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 else: e = "???" - v = " ".join([str(n) for n in remains]) + v = " ".join([str(n) for n in nb_remaining]) log_string(f"filter c_quizzes {v}") log_string( @@ -526,11 +544,9 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 v_train = validated_quizzes[:nb_for_train] quiz_machine.store_c_quizzes(v_train, for_train=True) - quiz_machine.store_c_quizzes(quiz_machine.problem.p_a_flip(v_train), for_train=True) v_test = validated_quizzes[nb_for_train:nb_to_validate] quiz_machine.store_c_quizzes(v_test, for_train=False) - quiz_machine.store_c_quizzes(quiz_machine.problem.p_a_flip(v_test), for_train=False) ###################################################################### # save images @@ -538,19 +554,19 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 vq = validated_quizzes[torch.randperm(validated_quizzes.size(0))[:128]] if vq.size(0) > 0: - prefix = f"culture_c_quiz_{n_epoch:04d}" - number_correct_responses = 0 for r in range(args.nb_rounds): number_correct_responses += quiz_machine.models_successes(models, vq) - with open(os.path.join(args.result_dir, prefix + "_responses.dat"), "w") as f: - for n, r in enumerate(number_correct_responses): - v = " ".join([str(n.item()) for n in r]) - f.write(f"{n}: {v}\n") + comments = [] + for r in number_correct_responses: + comments.append("nb_correct " + " ".join([str(n.item()) for n in r])) vq = quiz_machine.problem.reconfigure(vq, ("A", "f_A", "B", "f_B")) - quiz_machine.problem.save_quizzes_as_image(args.result_dir, prefix, vq) + filename = f"culture_c_quiz_{n_epoch:04d}.png" + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, filename, vq, comments=comments + ) ###################################################################### @@ -574,16 +590,22 @@ for k in range(args.nb_gpts): model.main_test_accuracy = 0.0 model.id = k - quiz_machine.create_w_quizzes( - model=model, - nb_train_samples=args.nb_train_samples, - nb_test_samples=args.nb_test_samples, + model.train_w_quizzes = quiz_machine.problem.generate_w_quizzes( + args.nb_train_samples ) + model.test_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples) + models.append(model) ###################################################################### +science_w_quizzes = quiz_machine.problem.generate_w_quizzes( + args.nb_test_samples, science=True +) + +###################################################################### + current_epoch = 0 if args.resume: @@ -728,6 +750,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): c_quizzes = quiz_machine.problem.reconfigure( c_quizzes, ("A", "f_A", "B", "f_B") ) + quiz_machine.problem.save_quizzes_as_image( args.result_dir, f"non_validated_{n_epoch:04d}_{model.id:02d}.png", diff --git a/quiz_machine.py b/quiz_machine.py index 2ca584e..4048b39 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -172,14 +172,16 @@ class QuizMachine: from_w = torch.arange( quizzes.size(0), device=quizzes.device ) < w_quizzes.size(0) - i = torch.randperm(quizzes.size(0), device=quizzes.device) - - return quizzes[i], from_w[i] else: - return w_quizzes, torch.full( - (w_quizzes.size(0),), True, device=w_quizzes.device - ) + quizzes = w_quizzes.clone() + from_w = torch.full((quizzes.size(0),), True, device=quizzes.device) + + self.randomize_configuations_inplace(quizzes, structs=self.train_struct) + + i = torch.randperm(quizzes.size(0), device=quizzes.device) + + return quizzes[i], from_w[i] ###################################################################### @@ -199,7 +201,6 @@ class QuizMachine: input=result, ar_mask=ar_mask, seq_logproba=seq_logproba, - deterministic_synthesis=False, progress_bar_desc="accuracy", device=self.device, ) @@ -219,7 +220,9 @@ class QuizMachine: result = input.new(input.size()) correct = input.new(input.size(0)) predicted_parts = input.new(input.size(0), 4) + nb = 0 + for struct, mask in [ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1)), (("f_A", "A", "f_B", "B"), (0, 0, 0, 1)), @@ -278,47 +281,38 @@ class QuizMachine: quizzes[r == c], struct=structs[c] ) - def create_w_quizzes(self, model, nb_train_samples, nb_test_samples): - model.train_w_quizzes = self.problem.generate_w_quizzes(nb_train_samples) - model.test_w_quizzes = self.problem.generate_w_quizzes(nb_test_samples) - - self.randomize_configuations_inplace( - model.train_w_quizzes, structs=self.train_struct - ) - - self.randomize_configuations_inplace( - model.test_w_quizzes, structs=self.train_struct - ) - ###################################################################### def renew_train_w_quizzes(self, model): if hasattr(model, "hard_w_quizzes"): - self.logger( - f"re-using {model.hard_w_quizzes.size(0)} hard world quizzes from model {model.id}" - ) - if model.hard_w_quizzes.size(0) >= model.train_w_quizzes.size(0): + nb_to_generate = 0 model.train_w_quizzes[...] = model.hard_w_quizzes[ torch.randperm(hard_w_quizzes.size(0))[ model.train_w_quizzes.size(0) ] ] else: + nb_to_generate = model.train_w_quizzes.size( + 0 + ) - model.hard_w_quizzes.size(0) model.train_w_quizzes[...] = torch.cat( [ model.hard_w_quizzes, - self.problem.generate_w_quizzes( - model.train_w_quizzes.size(0) - model.hard_w_quizzes.size(0) - ), + self.problem.generate_w_quizzes(nb_to_generate), ], dim=0, ) else: + nb_to_generate = 0 model.train_w_quizzes[...] = self.problem.generate_w_quizzes( model.train_w_quizzes.size(0) ) + self.logger( + f"re-using {model.hard_w_quizzes.size(0)} hard world quizzes from model {model.id}" + ) + self.randomize_configuations_inplace( model.train_w_quizzes, structs=self.train_struct ) @@ -409,7 +403,6 @@ class QuizMachine: input=result, ar_mask=ar_mask, seq_logproba=seq_logproba[:, model.id], - deterministic_synthesis=False, device=self.device, ) @@ -430,7 +423,6 @@ class QuizMachine: input=result, ar_mask=ar_mask, seq_logproba=seq_logproba[:, model.id], - deterministic_synthesis=False, device=self.device, ) @@ -451,9 +443,8 @@ class QuizMachine: temperature_hot=1.0, temperature_cold=1.0, ): - c_quizzes = self.problem.create_empty_quizzes(nb, ("f_B", "f_A", "A", "B")).to( - self.device - ) + c_quizzes = self.problem.create_empty_quizzes(nb, ("f_B", "f_A", "A", "B")) + c_quizzes = c_quizzes.to(self.device) seq_logproba = torch.zeros(nb, device=self.device) @@ -469,7 +460,6 @@ class QuizMachine: ), seq_logproba=seq_logproba, logit_transformer=lt_noisy, - deterministic_synthesis=False, device=self.device, ) @@ -482,7 +472,6 @@ class QuizMachine: ), seq_logproba=seq_logproba, logit_transformer=lt_clean, - deterministic_synthesis=False, device=self.device, ) @@ -497,7 +486,6 @@ class QuizMachine: ), seq_logproba=seq_logproba, logit_transformer=lt_clean, - deterministic_synthesis=False, device=self.device, ) -- 2.20.1