From 334d401c6f04003b84e9d2a35789e070fa8b8cb7 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 19 Jul 2024 16:21:24 +0200 Subject: [PATCH] Update. --- grids.py | 39 ++++++++++++++++++++++++++++--------- main.py | 3 +++ quiz_machine.py | 51 +++++++++++++++++++++++++++++++++++-------------- 3 files changed, 70 insertions(+), 23 deletions(-) diff --git a/grids.py b/grids.py index e1eff00..4db12db 100755 --- a/grids.py +++ b/grids.py @@ -118,7 +118,7 @@ class Grids(problem.Problem): ("gray", [128, 128, 128]), ] - def make_ar_mask(self, quizzes, first=False): + def make_ar_mask(self, quizzes, shape="fwd_3_bck_123"): S = self.height * self.width assert ( @@ -133,12 +133,17 @@ class Grids(problem.Problem): T = torch.arange(quizzes.size(1), device=quizzes.device) - if first: + if shape == "fwd_3_bck_123": + forward_mask = ((T % (S + 1) != 0) & (T >= 3 * (S + 1))).long() + backward_mask = ((T % (S + 1) != 0) & (T >= S + 1)).long() + elif shape == "fwd_012_bck_0": forward_mask = ((T % (S + 1) != 0) & (T < 3 * (S + 1))).long() backward_mask = ((T % (S + 1) != 0) & (T < S + 1)).long() - else: + elif shape == "fwd_3_bck_3": forward_mask = ((T % (S + 1) != 0) & (T >= 3 * (S + 1))).long() - backward_mask = ((T % (S + 1) != 0) & (T >= S + 1)).long() + backward_mask = ((T % (S + 1) != 0) & (T >= 3 * (S + 1))).long() + else: + raise ValueError(shape) is_forward = (quizzes[:, 0] == self.token_forward).long() @@ -147,7 +152,7 @@ class Grids(problem.Problem): + (1 - is_forward)[:, None] * backward_mask[None, :] ) - def p_a_flip(self, quizzes): + def p_a_flip(self, quizzes, pairwise_flip=False): S = self.height * self.width assert ( @@ -160,10 +165,26 @@ class Grids(problem.Problem): & (quizzes[:, 0] == quizzes[:, 3 * (S + 1)]) ).all() - flipped = torch.cat( - [quizzes[:, k * (S + 1) : (k + 1) * (S + 1)] for k in range(3, -1, -1)], - dim=1, - ) + if pairwise_flip: + flipped = torch.cat( + [ + quizzes[:, 1 * (S + 1) : 1 * (S + 1) + S + 1], + quizzes[:, 0 * (S + 1) : 0 * (S + 1) + S + 1], + quizzes[:, 3 * (S + 1) : 3 * (S + 1) + S + 1], + quizzes[:, 2 * (S + 1) : 2 * (S + 1) + S + 1], + ], + dim=1, + ) + else: + flipped = torch.cat( + [ + quizzes[:, 3 * (S + 1) : 3 * (S + 1) + S + 1], + quizzes[:, 2 * (S + 1) : 2 * (S + 1) + S + 1], + quizzes[:, 1 * (S + 1) : 1 * (S + 1) + S + 1], + quizzes[:, 0 * (S + 1) : 0 * (S + 1) + S + 1], + ], + dim=1, + ) m = (flipped[:, 0] == self.token_forward).long() flipped[:, 0 * (S + 1)] = m * self.token_backward + (1 - m) * self.token_forward diff --git a/main.py b/main.py index 0182e6a..562a95d 100755 --- a/main.py +++ b/main.py @@ -425,10 +425,13 @@ def keep_good_quizzes(models, quizzes): elif args.c_quiz_validation_mode == "predict": nc = quiz_machine.solution_nb_correct(models, quizzes) + count_nc = tuple( n.item() for n in F.one_hot(nc, num_classes=len(models) + 1).sum(dim=0) ) + log_string(f"nb_correct {count_nc}") + to_keep = nc == (len(models) - 1) else: diff --git a/quiz_machine.py b/quiz_machine.py index 046ab73..91eb3ac 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -293,7 +293,7 @@ class QuizMachine: def produce_results(self, n_epoch, model, result_dir, deterministic_synthesis): def compute_accuracy(input, log_prefix=None): input = input.to(self.device) - ar_mask = self.problem.make_ar_mask(input) + ar_mask = self.problem.make_ar_mask(input, shape="fwd_3_bck_123") result = input.clone() * (1 - ar_mask) seq_logproba = torch.empty(input.size(0), device=self.device) @@ -432,7 +432,7 @@ class QuizMachine: c_quizzes.split(self.batch_size), logproba.split(self.batch_size) ): input = input.to(self.device) - ar_mask = self.problem.make_ar_mask(input) + ar_mask = self.problem.make_ar_mask(input, shape="fwd_3_bck_123") output = model(mygpt.BracketedSequence(input)).x l[:, model.id] = ( -F.cross_entropy( @@ -448,10 +448,7 @@ class QuizMachine: ############################################################### def solution_nb_correct( - self, - models_for_validation, - c_quizzes, - deterministic_validation=False, + self, models_for_validation, c_quizzes, bidirectional_validation=True ): seq_logproba = torch.zeros( c_quizzes.size(0), @@ -464,10 +461,11 @@ class QuizMachine: seq_logproba[...] = 0.0 for model in models_for_validation: + # A, f(A), B | f(B) c_quizzes = c_quizzes.to(self.device) result = c_quizzes.clone() - ar_mask = self.problem.make_ar_mask(result) + ar_mask = self.problem.make_ar_mask(result, shape="fwd_3_bck_3") masked_inplace_autoregression( model=model, @@ -476,13 +474,38 @@ class QuizMachine: ar_mask=ar_mask, seq_logproba=seq_logproba[:, model.id], temperature=1.0, - deterministic_synthesis=deterministic_validation, + deterministic_synthesis=False, device=self.device, ) correct = (c_quizzes == result).long().min(dim=-1).values - nb_correct += correct + # ------------------------------- + + # f(A), A, f(B) | B + c_quizzes = self.problem.p_a_flip(c_quizzes, pairwise_flip=True).to( + self.device + ) + result = c_quizzes.clone() + + ar_mask = self.problem.make_ar_mask(result, shape="fwd_3_bck_3") + + masked_inplace_autoregression( + model=model, + batch_size=self.batch_size, + input=result, + ar_mask=ar_mask, + seq_logproba=seq_logproba[:, model.id], + temperature=1.0, + deterministic_synthesis=False, + device=self.device, + ) + + flipped_correct = (c_quizzes == result).long().min(dim=-1).values + + # ------------------------------- + + nb_correct += correct * flipped_correct return nb_correct.to("cpu") @@ -512,7 +535,7 @@ class QuizMachine: model=model_for_generation, batch_size=self.batch_size, input=c_quizzes, - ar_mask=self.problem.make_ar_mask(c_quizzes, first=True), + ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"), seq_logproba=seq_logproba, temperature=temperature_hot, deterministic_synthesis=False, @@ -523,7 +546,7 @@ class QuizMachine: model=model_for_generation, batch_size=self.batch_size, input=c_quizzes, - ar_mask=self.problem.make_ar_mask(c_quizzes), + ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"), seq_logproba=seq_logproba, temperature=temperature_cold, deterministic_synthesis=False, @@ -537,7 +560,7 @@ class QuizMachine: model=model_for_generation, batch_size=self.batch_size, input=c_quizzes, - ar_mask=self.problem.make_ar_mask(c_quizzes, first=True), + ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"), seq_logproba=seq_logproba, temperature=temperature_hot, deterministic_synthesis=False, @@ -548,7 +571,7 @@ class QuizMachine: model=model_for_generation, batch_size=self.batch_size, input=c_quizzes, - ar_mask=self.problem.make_ar_mask(c_quizzes), + ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"), seq_logproba=seq_logproba, temperature=temperature_cold, deterministic_synthesis=False, @@ -561,7 +584,7 @@ class QuizMachine: model=model_for_generation, batch_size=self.batch_size, input=c_quizzes, - ar_mask=self.problem.make_ar_mask(c_quizzes), + ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"), seq_logproba=seq_logproba, temperature=temperature_cold, deterministic_synthesis=False, -- 2.20.1