From f7a8cd141a39039048d3e9311220a33079f2cfc7 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 9 Jul 2024 23:00:50 +0200 Subject: [PATCH] Update. --- main.py | 14 +++++--------- problem.py | 6 ++++++ quiz_machine.py | 19 +++++++++++++++++++ 3 files changed, 30 insertions(+), 9 deletions(-) diff --git a/main.py b/main.py index 3004f9c..57f79a3 100755 --- a/main.py +++ b/main.py @@ -392,7 +392,7 @@ def run_tests(model, quiz_machine, deterministic_synthesis): def standard_validity(logproba): l = logproba.sort(dim=-1).values - return logical_and(l[0] < math.log(0.5), l[1] > math.log(0.95)) + return (l[:, 0] < math.log(0.5)) & (l[:, 1] > math.log(0.95)) def valid_c_quizzes(recorded, criteria): @@ -435,17 +435,10 @@ def create_c_quizzes( c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)] if c_quizzes.size(0) > 0: - logproba = c_quizzes.new(c_quizzes.size(0), len(models)) - for q, l in zip( - c_quizzes.split(args.batch_size), logproba.split(args.batch_size) - ): - for model in models: - l[model.id] = F.cross_entropy(model(q)) - + logproba = quiz_machine.logproba_solution(models, c_quizzes) for l in logproba: s = " ".join([str(x.item()) for x in l]) logp_file.write(s + "\n") - quizzes_and_logproba_records.append((c_quizzes, logproba)) nb_validated = valid_c_quizzes( @@ -655,6 +648,9 @@ for n_epoch in range(args.nb_epochs): ################################################## # Replace a fraction of the w_quizzes with fresh ones + log_string( + f"cache_w_quizzes contains {quiz_machine.problem.nb_cached_quizzes()} quizzes" + ) quiz_machine.renew_w_quizzes(args.nb_train_samples // args.nb_gpts) ################################################## diff --git a/problem.py b/problem.py index eceb904..617b2a8 100755 --- a/problem.py +++ b/problem.py @@ -19,6 +19,12 @@ class Problem: else: self.queue = None + def nb_cached_quizzes(self): + if self.queue is None: + return None + else: + return self.queue.qsize() * self.chunk_size + def nb_token_values(self): pass diff --git a/quiz_machine.py b/quiz_machine.py index 321df35..c1477c9 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -416,6 +416,25 @@ class QuizMachine: else: self.test_c_quizzes.append(new_c_quizzes) + def logproba_solution(self, models, c_quizzes): + logproba = c_quizzes.new_zeros(c_quizzes.size(0), len(models)) + + for model in models: + for input, l in zip( + c_quizzes.split(self.batch_size), logproba.split(self.batch_size) + ): + ar_mask = self.make_ar_mask(input) + output = model(mygpt.BracketedSequence(input)).x + ce = ( + F.cross_entropy(output.transpose(1, 2), input, reduction="none") + * ar_mask + ) + l[:, model.id] = ce.sum(dim=-1) + + return logproba + + ############################################################### + def compute_correctness( self, c_quizzes, -- 2.39.5