From f5a82bf034c425f80d786f5d069cdb16f47b9a44 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 24 Jul 2024 16:42:24 +0200 Subject: [PATCH] Update. --- grids.py | 34 ++++++++++++---------- main.py | 13 ++------- problem.py | 75 ++++++++++++++++++++++++------------------------- quiz_machine.py | 66 ++++++++++++++++++++++++------------------- 4 files changed, 96 insertions(+), 92 deletions(-) diff --git a/grids.py b/grids.py index 99a9240..37ed6a0 100755 --- a/grids.py +++ b/grids.py @@ -204,10 +204,11 @@ class Grids(problem.Problem): self.token_f_B: "f_B", } - self.nb_token_values = self.token_f_B + 1 - self.height = 10 self.width = 10 + self.seq_len = 4 * (1 + self.height * self.width) + self.nb_token_values = self.token_f_B + 1 + self.cache_rec_coo = {} all_tasks = [ @@ -1378,27 +1379,30 @@ class Grids(problem.Problem): ###################################################################### + def create_empty_quizzes(self, nb, struct=("A", "f_A", "B", "f_B")): + S = self.height * self.width + quizzes = torch.zeros(nb, 4 * (S + 1), dtype=torch.int64) + quizzes[:, 0 * (S + 1)] = self.l2tok(struct[0]) + quizzes[:, 1 * (S + 1)] = self.l2tok(struct[1]) + quizzes[:, 2 * (S + 1)] = self.l2tok(struct[2]) + quizzes[:, 3 * (S + 1)] = self.l2tok(struct[3]) + + return quizzes + def generate_w_quizzes_(self, nb, tasks=None, progress_bar=False): if tasks is None: tasks = self.all_tasks - S = self.height * self.width - quizzes = torch.empty(nb, 4 * (S + 1), dtype=torch.int64) + quizzes = self.create_empty_quizzes(nb, ("A", "f_A", "B", "f_B")) if progress_bar: quizzes = tqdm.tqdm( quizzes, dynamic_ncols=True, desc="world quizzes generation", - total=prompts.size(0), + total=quizzes.size(0), ) - quizzes[...] = 0 - quizzes[:, 0 * (S + 1)] = self.token_A - quizzes[:, 1 * (S + 1)] = self.token_f_A - quizzes[:, 2 * (S + 1)] = self.token_B - quizzes[:, 3 * (S + 1)] = self.token_f_B - for quiz in quizzes: q = quiz.reshape(4, S + 1)[:, 1:].reshape(4, self.height, self.width) q[...] = 0 @@ -1412,9 +1416,9 @@ class Grids(problem.Problem): nb, nrow = 128, 4 for t in self.all_tasks: print(t.__name__) - prompts, answers = self.generate_w_quizzes_(nb, tasks=[t]) + quizzes = self.generate_w_quizzes_(nb, tasks=[t]) self.save_quizzes_as_image( - result_dir, t.__name__, prompts[:nb], answers[:nb], nrow=nrow + result_dir, t.__name__ + ".png", quizzes, nrow=nrow ) @@ -1499,9 +1503,9 @@ if __name__ == "__main__": predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1) predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1) - grids.save_quiz_illustrations( + grids.save_quizzes_as_image( "/tmp", - "test", + "test.png", prompts[:nb], answers[:nb], # You can add a bool to put a frame around the predicted parts diff --git a/main.py b/main.py index 122dd31..fcca116 100755 --- a/main.py +++ b/main.py @@ -103,8 +103,6 @@ parser.add_argument("--nb_rounds", type=int, default=3) parser.add_argument("--c_quiz_validation_mode", type=str, default="predict") -parser.add_argument("--p2a_only", action="store_true", default=False) - parser.add_argument("--dirty_debug", action="store_true", default=False) ###################################################################### @@ -394,11 +392,9 @@ def one_epoch(model, quiz_machine, local_device=main_device): acc_train_loss += loss.item() * input.size(0) loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1) - n_p2a = input[:, 0] == quiz_machine.problem.token_forward - to_store = from_w & n_p2a.to("cpu") - if to_store.any(): + if from_w.any(): hard_w_quizzes.append( - (input[to_store].to("cpu"), loss_per_samples[to_store].to("cpu")) + (input[from_w].to("cpu"), loss_per_samples[from_w].to("cpu")) ) nb_train_samples += input.size(0) @@ -452,7 +448,6 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 c_quizzes = quiz_machine.generate_c_quizzes( nb_to_generate_per_iteration, model_for_generation=model_for_generation, - p2a_only=args.p2a_only, temperature_hot=args.temperature_hot, temperature_cold=args.temperature_cold, ) @@ -585,7 +580,6 @@ for k in range(args.nb_gpts): model=model, nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - p2a_only=args.p2a_only, ) models.append(model) @@ -729,7 +723,6 @@ for n_epoch in range(current_epoch, args.nb_epochs): c_quizzes = quiz_machine.generate_c_quizzes( 128, model_for_generation=model, - p2a_only=args.p2a_only, temperature_hot=args.temperature_hot, temperature_cold=args.temperature_cold, ) @@ -741,7 +734,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): # Renew the training samples for model in weakest_models: - quiz_machine.renew_train_w_quizzes(model=model, p2a_only=args.p2a_only) + quiz_machine.renew_train_w_quizzes(model=model) if args.log_command is not None: s = args.log_command.split() diff --git a/problem.py b/problem.py index 61e4834..50376d6 100755 --- a/problem.py +++ b/problem.py @@ -25,46 +25,23 @@ class Problem: else: return self.queue.qsize() * self.chunk_size - def nb_token_values(self): - pass - - def trivial_prompts_and_answers(self, prompts, answers): - pass - - # The one to implement, returns two tensors nb x D and nb x D' - def generate_w_quizzes_(self, nb): - pass - - # save a file to vizualize quizzes, you can save a txt or png file - def save_quiz_illustrations( - self, - result_dir, - filename_prefix, - prompts, - answers, - predicted_prompts=None, - predicted_answers=None, - ): - pass - def fill_cache(self): while True: - prompts, answers = self.generate_w_quizzes_(self.chunk_size) - - self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True) + quizzes = self.generate_w_quizzes_(self.chunk_size) + self.queue.put(quizzes.to("cpu"), block=True) def generate_w_quizzes(self, nb): if self.queue is None: return self.generate_w_quizzes_(nb) if self.rest is not None: - prompts, answers = rest + quizzes = rest else: - prompts, answers = [], [] + quizzes = [] self.rest = None - n = sum([p.size(0) for p in prompts]) + n = sum([q.size(0) for q in quizzes]) with tqdm.tqdm( total=nb, @@ -72,22 +49,44 @@ class Problem: desc="world generation", ) as pbar: while n < nb: - p, s = self.queue.get(block=True) - prompts.append(p) - answers.append(s) - n += p.size(0) - pbar.update(p.size(0)) + q = self.queue.get(block=True) + quizzes.append(q) + n += q.size(0) + pbar.update(q.size(0)) - prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0) - assert n == prompts.size(0) + quizzes = torch.cat(quizzes, dim=0) + assert n == quizzes.size(0) k = n - nb if k > 0: - rest = (prompts[-k:], answers[-k:]) - prompts, answers = prompts[:-k], answers[:-k] + rest = quizzes[-k:] + quizzes = quizzes[:-k] - return prompts, answers + return quizzes + + ###################################################################### + + def trivial_prompts_and_answers(self, prompts, answers): + pass + + # The one to implement, returns two tensors nb x D and nb x D' + def generate_w_quizzes_(self, nb): + pass + + # save a file to vizualize quizzes, you can save a txt or png file + def save_quiz_illustrations( + self, + result_dir, + filename_prefix, + prompts, + answers, + predicted_prompts=None, + predicted_answers=None, + ): + pass def save_some_examples(self, result_dir): pass + + ###################################################################### diff --git a/quiz_machine.py b/quiz_machine.py index bc2a358..bb62181 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -174,36 +174,36 @@ class QuizMachine: ###################################################################### - def produce_results( - self, n_epoch, model, input, result_dir, deterministic_synthesis - ): - def predict(input, struct, mask): - ar_mask = self.problem.make_ar_mask( - quizzes=quizzes, struct=struct, mask=mask - ) - result = quizzes * (1 - ar_mask) - seq_logproba = torch.empty(fwd_quizzes, device=self.device) + def predict(self, input, struct, mask): + ar_mask = self.problem.make_ar_mask(quizzes=quizzes, struct=struct, mask=mask) + result = quizzes * (1 - ar_mask) - masked_inplace_autoregression( - model=model, - batch_size=self.batch_size, - input=result, - ar_mask=ar_mask, - seq_logproba=seq_logproba, - deterministic_synthesis=deterministic_synthesis, - progress_bar_desc="accuracy", - device=self.device, - ) + seq_logproba = torch.empty(fwd_quizzes, device=self.device) - nb_correct = (result == quizzes).min(dim=1).long() + masked_inplace_autoregression( + model=model, + batch_size=self.batch_size, + input=result, + ar_mask=ar_mask, + seq_logproba=seq_logproba, + deterministic_synthesis=deterministic_synthesis, + progress_bar_desc="accuracy", + device=self.device, + ) + + nb_correct = (result == quizzes).min(dim=1).long() - return result, correct + return result, correct + def produce_results( + self, n_epoch, model, input, result_dir, deterministic_synthesis + ): input = input.to(self.device) i = self.problem.indices_select(quizzes=input, struct=struct) + input_fwd = input[i] test_result_fwd, test_correct_fwd = predict( - input[i], ("A", "f_A", "B", "f_B"), (0, 0, 0, 1) + input_fwd, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1) ) input_bck = self.problem.reconfigure( @@ -211,8 +211,9 @@ class QuizMachine: struct=("A", "f_A", "B", "f_B"), ) - l = input_bck.size(1) + l = input_bck.size(1) // 4 input_bck[:, 3 * l :] = input[i == False][:, :l] + test_result_bck, test_correct_bck = predict( input_bck, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1) ) @@ -221,11 +222,14 @@ class QuizMachine: ############################## + test_result = torch.cat([test_result_fwd[:64], test_result_bck[:64]], dim=0) + test_correct = torch.cat([test_correct_fwd[:64], test_correct_bck[:64]], dim=0) + self.save_quiz_illustrations( result_dir, f"culture_prediction_{n_epoch:04d}_{model.id:02d}", - quizzes=test_result[:128], - mistakes=test_correct[:128] * 2 - 1, + quizzes=test_result, + # mistakes=test_correct, ) return main_test_accuracy @@ -233,12 +237,16 @@ class QuizMachine: ###################################################################### def flip_half_in_place(self, quizzes): - r = torch.randint(quizzes.size(0), device=quizzes.device) < 0.5 - i = self.problem.indices_select(quizzes=input, struct=("A", "f_A", "B", "f_B")) + r = torch.rand(quizzes.size(0), device=quizzes.device) < 0.5 + i = self.problem.indices_select( + quizzes=quizzes, struct=("A", "f_A", "B", "f_B") + ) quizzes[i & r] = self.problem.reconfigure( quizzes[i & r], struct=("f_B", "f_A", "B", "A") ) - j = self.problem.indices_select(quizzes=input, struct=("f_B", "f_A", "B", "A")) + j = self.problem.indices_select( + quizzes=quizzes, struct=("f_B", "f_A", "B", "A") + ) quizzes[j & r] = self.problem.reconfigure( quizzes[j & r], struct=("A", "f_A", "B", "f_B") ) @@ -403,7 +411,7 @@ class QuizMachine: ): c_quizzes = torch.empty( nb, - self.prompt_len + self.answer_len, + self.problem.seq_len, device=self.device, dtype=torch.int64, ) -- 2.20.1