From 4a63b2b44bc08cb04b236b35a3d36aa242912d48 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 11 Jul 2024 17:37:46 +0200 Subject: [PATCH] Update. --- main.py | 18 ++++++++---------- quiz_machine.py | 16 ++++++++++------ 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/main.py b/main.py index 73e7ca2..4cf4d59 100755 --- a/main.py +++ b/main.py @@ -341,7 +341,7 @@ def one_epoch(model, quiz_machine, local_device=None): train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) - log_string(f"train_perplexity {n_epoch} {train_perplexity}") + log_string(f"train_perplexity {n_epoch} model.id {model.id} {train_perplexity}") run_tests(model, quiz_machine, deterministic_synthesis=False) @@ -354,9 +354,6 @@ def one_epoch(model, quiz_machine, local_device=None): def standard_validity(logproba): l = logproba.sort(dim=-1).values return (l[:, 0] < math.log(0.5)) & (l[:, 1] > math.log(0.99)) - # warnings.warn("TEST!!!", RuntimeWarning) - # print(l.exp()) - # return (l[:, 0] < math.log(0.99)) def valid_c_quizzes(recorded, criteria): @@ -452,13 +449,9 @@ for k in range(args.nb_gpts): model.id = k model.TRAINING_LOCK = threading.Lock() - model.train_w_quizzes = quiz_machine.generate_token_sequences( - args.nb_train_samples - ).to(device) + model.train_w_quizzes = quiz_machine.generate_token_sequences(args.nb_train_samples) quiz_machine.reverse_random_half_in_place(model.train_w_quizzes) - model.test_w_quizzes = quiz_machine.generate_token_sequences( - args.nb_test_samples - ).to(device) + model.test_w_quizzes = quiz_machine.generate_token_sequences(args.nb_test_samples) quiz_machine.reverse_random_half_in_place(model.test_w_quizzes) models.append(model) @@ -532,6 +525,11 @@ if args.dirty_debug: nb_new_c_quizzes_for_train = 100 nb_new_c_quizzes_for_test = 10 + def standard_validity(logproba): + l = logproba.sort(dim=-1).values + return l[:, 0] < math.log(0.5) + + ###################################################################### for n_epoch in range(args.nb_epochs): diff --git a/quiz_machine.py b/quiz_machine.py index 1f1046d..ae14614 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -327,6 +327,7 @@ class QuizMachine: self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000 ): def compute_accuracy(input, log_prefix=None): + input = input.to(self.device) ar_mask = self.make_ar_mask(input) result = input.clone() * (1 - ar_mask) seq_logproba = torch.empty(input.size(0), device=self.device) @@ -404,26 +405,29 @@ class QuizMachine: input[:-nb] = input[nb:].clone() fresh_w_quizzes = self.generate_token_sequences(nb) self.reverse_random_half_in_place(fresh_w_quizzes) - input[-nb:] = fresh_w_quizzes.to(self.device) + input[-nb:] = fresh_w_quizzes.to("cpu") ###################################################################### def store_c_quizzes(self, new_c_quizzes, for_train=True): with self.LOCK_C_QUIZZES: if for_train: - self.train_c_quizzes.append(new_c_quizzes) + self.train_c_quizzes.append(new_c_quizzes.to("cpu")) else: - self.test_c_quizzes.append(new_c_quizzes) + self.test_c_quizzes.append(new_c_quizzes.to("cpu")) ###################################################################### def logproba_of_solutions(self, models, c_quizzes): - logproba = c_quizzes.new_zeros(c_quizzes.size(0), len(models)) + logproba = c_quizzes.new_zeros( + c_quizzes.size(0), len(models), device=self.device + ) for model in models: for input, l in zip( c_quizzes.split(self.batch_size), logproba.split(self.batch_size) ): + input = input.to(self.device) ar_mask = self.make_ar_mask(input) output = model(mygpt.BracketedSequence(input)).x ce = ( @@ -432,7 +436,7 @@ class QuizMachine: ) l[:, model.id] = -ce.sum(dim=-1) - return logproba + return logproba.to("cpu") ############################################################### @@ -561,4 +565,4 @@ class QuizMachine: device=self.device, ) - return c_quizzes + return c_quizzes.to("cpu") -- 2.39.5