From 15192743a5dee8d88650319d64610f1603d21472 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Tue, 25 Jun 2024 09:57:35 +0200
Subject: [PATCH] Update.

---
 main.py  |  68 +++++++++++++++++++----------------
 tasks.py | 106 +++++++++++++++++++++++++++++--------------------------
 2 files changed, 93 insertions(+), 81 deletions(-)

diff --git a/main.py b/main.py
index 8033836..ebecad8 100755
--- a/main.py
+++ b/main.py
@@ -14,6 +14,14 @@ from torch.nn import functional as F
 import ffutils
 import mygpt, tasks
 
+# world quizzes vs. culture quizzes
+
+######################################################################
+
+accuracy_to_make_c_quizzes = 0.975
+nb_new_c_quizzes_for_train = 1000
+nb_new_c_quizzes_for_test = 100
+
 ######################################################################
 
 if torch.cuda.is_available():
@@ -84,6 +92,13 @@ if args.result_dir is None:
 
 ######################################################################
 
+if args.dirty_debug:
+    accuracy_to_make_c_quizzes = 0.0
+    nb_new_c_quizzes_for_train = 100
+    nb_new_c_quizzes_for_test = 10
+
+######################################################################
+
 default_args = {
     "model": "37M",
     "batch_size": 100,
@@ -329,7 +344,7 @@ def run_tests(model, task, deterministic_synthesis):
 ######################################################################
 
 
-def create_quizzes(
+def create_c_quizzes(
     model,
     other_models,
     task,
@@ -339,12 +354,12 @@ def create_quizzes(
 ):
     kept = []
-    sum_logits, sum_nb_quizzes = 0, 0
+    sum_logits, sum_nb_c_quizzes = 0, 0
 
     while sum([x.size(0) for x in kept]) < nb_for_train + nb_for_test:
         nb_to_generate = 4 * (nb_for_train + nb_for_test)
 
-        new_quizzes, nb_correct, average_logits = task.create_new_quizzes(
+        new_c_quizzes, nb_correct, average_logits = task.create_c_quizzes(
             n_epoch=n_epoch,
             result_dir=args.result_dir,
             logger=log_string,
@@ -354,33 +369,33 @@ def create_quizzes(
             desired_average_logits=desired_average_logits,
         )
 
-        sum_logits += new_quizzes.size(0) * average_logits
-        sum_nb_quizzes += new_quizzes.size(0)
+        sum_logits += new_c_quizzes.size(0) * average_logits
+        sum_nb_c_quizzes += new_c_quizzes.size(0)
 
-        to_keep = new_quizzes[nb_correct == len(other_models) - 1]
+        to_keep = new_c_quizzes[nb_correct == len(other_models) - 1]
 
         if args.dirty_debug:
-            to_keep = new_quizzes
+            to_keep = new_c_quizzes
 
         log_string(
-            f"keep {to_keep.size(0)}/{new_quizzes.size(0)} quizzes ({to_keep.size(0)*100/new_quizzes.size(0):.02f}%)"
+            f"keep {to_keep.size(0)}/{new_c_quizzes.size(0)} c_quizzes ({to_keep.size(0)*100/new_c_quizzes.size(0):.02f}%)"
         )
 
         kept.append(to_keep)
 
-    new_quizzes = torch.cat(kept, dim=0)[: nb_for_train + nb_for_test]
+    new_c_quizzes = torch.cat(kept, dim=0)[: nb_for_train + nb_for_test]
 
-    task.store_new_quizzes(new_quizzes[:nb_for_train], for_train=True)
-    task.store_new_quizzes(new_quizzes[nb_for_train:], for_train=False)
+    task.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
+    task.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
 
-    task.save_image(
-        new_quizzes[:72],
+    task.save_quizzes(
+        new_c_quizzes[:72],
         args.result_dir,
-        f"world_quiz_{n_epoch:04d}_{model.id:02d}.png",
+        f"culture_c_quiz_{n_epoch:04d}_{model.id:02d}",
         log_string,
     )
 
-    return sum_logits / sum_nb_quizzes
+    return sum_logits / sum_nb_c_quizzes
 
 
 ######################################################################
@@ -410,15 +425,6 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 ######################################################################
 
-accuracy_to_make_quizzes = 0.975
-nb_new_quizzes_for_train = 1000
-nb_new_quizzes_for_test = 100
-
-if args.dirty_debug:
-    accuracy_to_make_quizzes = 0.0
-    nb_new_quizzes_for_train = 100
-    nb_new_quizzes_for_test = 10
-
 desired_average_logits = None
 
 for n_epoch in range(args.nb_epochs):
@@ -439,29 +445,29 @@ for n_epoch in range(args.nb_epochs):
     # improve it
     one_epoch(model, task)
 
-    task.renew_samples(args.nb_train_samples // args.nb_gpts)
+    task.renew_w_quizzes(args.nb_train_samples // args.nb_gpts)
 
     log_string(
-        f"train_set_composition world {task.nb_batch_samples_world} quizzes {task.nb_batch_samples_quizzes}"
+        f"train_set_composition w_quizzes {task.nb_batch_w_quizzes} c_quizzes {task.nb_batch_c_quizzes}"
     )
 
    # test it
     run_tests(model, task, deterministic_synthesis=False)
 
     log_string(
-        f"test_set_composition world {task.nb_batch_samples_world} quizzes {task.nb_batch_samples_quizzes}"
+        f"test_set_composition w_quizzes {task.nb_batch_w_quizzes} c_quizzes {task.nb_batch_c_quizzes}"
    )
 
-    if min([m.main_test_accuracy for m in models]) >= accuracy_to_make_quizzes:
+    if min([m.main_test_accuracy for m in models]) >= accuracy_to_make_c_quizzes:
         other_models = models.copy()
         other_models.remove(model)
 
-        average_logits = create_quizzes(
+        average_logits = create_c_quizzes(
             model,
             other_models,
             task,
-            nb_for_train=nb_new_quizzes_for_train,
-            nb_for_test=nb_new_quizzes_for_test,
+            nb_for_train=nb_new_c_quizzes_for_train,
+            nb_for_test=nb_new_c_quizzes_for_test,
             desired_average_logits=desired_average_logits,
         )
 
diff --git a/tasks.py b/tasks.py
index ee06c25..43f7d53 100755
--- a/tasks.py
+++ b/tasks.py
@@ -88,6 +88,9 @@ class World(Task):
         torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
         logger(f"wrote {image_name}")
 
+    def save_quizzes(self, input, result_dir, filename_prefix, logger):
+        self.save_image(input, result_dir, filename_prefix + ".png", logger)
+
     def make_ar_mask(self, input):
         b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2
         return b.long()[None, :].expand_as(input)
@@ -108,49 +111,52 @@ class World(Task):
         self.height = 6
         self.width = 8
 
-        self.train_input = world.generate_seq(
+        self.train_w_quizzes = world.generate_seq(
             nb_train_samples, height=self.height, width=self.width
         ).to(device)
 
-        self.test_input = world.generate_seq(
+        self.test_w_quizzes = world.generate_seq(
             nb_test_samples, height=self.height, width=self.width
         ).to(device)
 
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = max(self.train_w_quizzes.max(), self.test_w_quizzes.max()) + 1
 
-        self.train_quizzes = []
-        self.test_quizzes = []
+        self.train_c_quizzes = []
+        self.test_c_quizzes = []
 
         if result_dir is not None:
-            self.save_image(
-                self.train_input[:72], result_dir, f"world_train.png", logger
+            self.save_quizzes(
+                self.train_w_quizzes[:72], result_dir, f"culture_w_quizzes", logger
             )
 
     def batches(self, split="train", desc=None):
         assert split in {"train", "test"}
         if split == "train":
-            input = self.train_input
-            quizzes = self.train_quizzes
+            w_quizzes = self.train_w_quizzes
+            c_quizzes = self.train_c_quizzes
         else:
-            input = self.test_input
-            quizzes = self.test_quizzes
+            w_quizzes = self.test_w_quizzes
+            c_quizzes = self.test_c_quizzes
 
-        if len(quizzes) > 0:
-            quizzes = torch.cat(quizzes, dim=0)
-            if quizzes.size(0) > input.size(0) // 2:
-                i = torch.randperm(input.size(0))[: input.size(0) // 2]
-                quizzes = quizzes[i]
+        if len(c_quizzes) > 0:
+            c_quizzes = torch.cat(c_quizzes, dim=0)
+            if c_quizzes.size(0) > w_quizzes.size(0) // 2:
+                i = torch.randperm(w_quizzes.size(0))[: w_quizzes.size(0) // 2]
+                c_quizzes = c_quizzes[i]
 
-            i = torch.randperm(input.size(0))[: input.size(0) - quizzes.size(0)]
-            input = input[i]
+            i = torch.randperm(w_quizzes.size(0))[
+                : w_quizzes.size(0) - c_quizzes.size(0)
+            ]
+            w_quizzes = w_quizzes[i]
 
-            self.nb_batch_samples_world = input.size(0)
-            self.nb_batch_samples_quizzes = quizzes.size(0)
+            self.nb_batch_w_quizzes = w_quizzes.size(0)
+            self.nb_batch_c_quizzes = c_quizzes.size(0)
 
-            input = torch.cat([input, quizzes], dim=0)
+            input = torch.cat([w_quizzes, c_quizzes], dim=0)
         else:
-            self.nb_batch_samples_world = input.size(0)
-            self.nb_batch_samples_quizzes = 0
+            input = w_quizzes
+            self.nb_batch_w_quizzes = w_quizzes.size(0)
+            self.nb_batch_c_quizzes = 0
 
         # Shuffle
         input = input[torch.randperm(input.size(0))]
@@ -192,13 +198,13 @@ class World(Task):
 
             return nb_total, nb_correct
 
-        train_nb_total, train_nb_correct = compute_accuracy(self.train_input)
+        train_nb_total, train_nb_correct = compute_accuracy(self.train_w_quizzes)
 
         logger(
             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
         )
 
-        test_nb_total, test_nb_correct = compute_accuracy(self.test_input, logger)
+        test_nb_total, test_nb_correct = compute_accuracy(self.test_w_quizzes, logger)
 
         logger(
             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
@@ -209,7 +215,7 @@ class World(Task):
 
         ##############################
 
-        input = self.test_input[:96]
+        input = self.test_w_quizzes[:96]
         ar_mask = self.make_ar_mask(input)
         result = input.clone() * (1 - ar_mask)
 
@@ -225,30 +231,30 @@ class World(Task):
             device=self.device,
         )
 
-        self.save_image(
+        self.save_quizzes(
             result[:72],
             result_dir,
-            f"world_prediction_{n_epoch:04d}_{model.id:02d}.png",
+            f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
             logger,
         )
 
         return main_test_accuracy
 
-    def renew_samples(self, nb, for_train=True):
-        input = self.train_input if for_train else self.test_input
+    def renew_w_quizzes(self, nb, for_train=True):
+        input = self.train_w_quizzes if for_train else self.test_w_quizzes
         nb = min(nb, input.size(0))
         input[:-nb] = input[nb:].clone()
         input[-nb:] = world.generate_seq(nb, height=self.height, width=self.width).to(
             self.device
         )
 
-    def store_new_quizzes(self, new_quizzes, for_train=True):
+    def store_c_quizzes(self, new_c_quizzes, for_train=True):
         if for_train:
-            self.train_quizzes.append(new_quizzes)
+            self.train_c_quizzes.append(new_c_quizzes)
         else:
-            self.test_quizzes.append(new_quizzes)
+            self.test_c_quizzes.append(new_c_quizzes)
 
-    def create_new_quizzes(
+    def create_c_quizzes(
         self,
         n_epoch,
         result_dir,
@@ -261,11 +267,11 @@ class World(Task):
         ###############################################################
         # Generate quizzes with model
 
-        quizzes = torch.empty(
+        c_quizzes = torch.empty(
             nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
         )
 
-        ar_mask = torch.full(quizzes.size(), 1, device=self.device)
+        ar_mask = torch.full(c_quizzes.size(), 1, device=self.device)
         summed_logits = torch.empty(nb, device=self.device)
 
         temperature = 1
@@ -277,12 +283,12 @@ class World(Task):
             masked_inplace_autoregression(
                 model=model,
                 batch_size=self.batch_size,
-                input=quizzes,
+                input=c_quizzes,
                 ar_mask=ar_mask,
                 summed_logits=summed_logits,
                 temperature=temperature,
                 deterministic_synthesis=False,
-                progress_bar_desc="creating quizzes",
+                progress_bar_desc="sampling c_quizzes",
                 device=self.device,
             )
 
@@ -311,15 +317,15 @@ class World(Task):
         # Create the reverse quizzes
         l = self.height * self.width
-        direction = quizzes[:, l : l + 1]
+        direction = c_quizzes[:, l : l + 1]
         direction = world.token_forward * (
             direction == world.token_backward
         ) + world.token_backward * (direction == world.token_forward)
 
-        reverse_quizzes = torch.cat(
-            [quizzes[:, l + 1 :], direction, quizzes[:, :l]], dim=1
+        reverse_c_quizzes = torch.cat(
+            [c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1
         )
 
-        ar_mask = self.make_ar_mask(quizzes)
+        ar_mask = self.make_ar_mask(c_quizzes)
 
         ###############################################################
         # Check how many of the other models can solve them in both
@@ -328,7 +334,7 @@ class World(Task):
         nb_correct = []
 
         for m in other_models:
-            result = quizzes.clone()
+            result = c_quizzes.clone()
 
             masked_inplace_autoregression(
                 model=m,
@@ -338,13 +344,13 @@ class World(Task):
                 summed_logits=None,
                 temperature=1.0,
                 deterministic_synthesis=True,
-                progress_bar_desc="solving quizzes",
+                progress_bar_desc="solving c_quizzes",
                 device=self.device,
             )
 
-            correct = (quizzes == result).long().min(dim=-1).values
+            correct = (c_quizzes == result).long().min(dim=-1).values
 
-            reverse_result = reverse_quizzes.clone()
+            reverse_result = reverse_c_quizzes.clone()
 
             masked_inplace_autoregression(
                 model=m,
@@ -354,21 +360,21 @@ class World(Task):
                 summed_logits=None,
                 temperature=1.0,
                 deterministic_synthesis=True,
-                progress_bar_desc="solving reversed quizzes",
+                progress_bar_desc="solving reversed c_quizzes",
                 device=self.device,
             )
 
             reverse_correct = (
-                (reverse_quizzes == reverse_result).long().min(dim=-1).values
+                (reverse_c_quizzes == reverse_result).long().min(dim=-1).values
             )
 
             nb_correct.append((correct * reverse_correct)[None, :])
 
-        nb_correct = torch.cat(nb_correct, dim=0)
+        nb_correct = torch.cat(nb_correct, dim=0).sum(dim=0)
 
         # filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat")
         # with open(filename, "w") as f:
         #     for k in nb_correct:
         #         f.write(f"{k}\n")
 
-        return quizzes, nb_correct.sum(dim=0), summed_logits.mean()
+        return c_quizzes, nb_correct, summed_logits.mean()
-- 
2.39.5
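Note (not part of the patch): the rule that ties the two files together is that a model-generated c_quiz is kept only when it is solved, in both its original and its reversed order, by all but one of the other models: create_c_quizzes in tasks.py returns the per-quiz count of such solvers, and main.py keeps the candidates with nb_correct == len(other_models) - 1. The sketch below restates that rule in isolation; the helper name keep_mask and the random tensors standing in for the solvers' deterministic completions are illustrative only, not from the repository.

import torch

def keep_mask(c_quizzes, reverse_c_quizzes, answers, reverse_answers):
    # answers / reverse_answers: one (N, L) tensor per other model, holding its
    # completion of the quizzes and of their reversed variants.
    nb_correct = []
    for a, ra in zip(answers, reverse_answers):
        # A solver counts only if it reproduces the quiz exactly in both directions.
        correct = (c_quizzes == a).long().min(dim=-1).values
        reverse_correct = (reverse_c_quizzes == ra).long().min(dim=-1).values
        nb_correct.append((correct * reverse_correct)[None, :])
    nb_correct = torch.cat(nb_correct, dim=0).sum(dim=0)
    # Keep a quiz only if all but one of the other models solved it.
    return nb_correct == len(answers) - 1

# Toy check with random token sequences standing in for quizzes and model answers.
c = torch.randint(10, (5, 8))
rc = torch.flip(c, dims=[1])  # stand-in for the real half-swapping reversal
answers = [c.clone(), c.clone(), torch.randint(10, (5, 8))]
reverse_answers = [rc.clone(), rc.clone(), torch.randint(10, (5, 8))]
print(keep_mask(c, rc, answers, reverse_answers))  # almost surely all True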