From: François Fleuret Date: Sat, 20 Jul 2024 21:47:51 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=44b760ab5e7b6ff266f2f097ecbfb48bd1b278b5;p=culture.git Update. --- diff --git a/grids.py b/grids.py index 4db12db..bbb18d2 100755 --- a/grids.py +++ b/grids.py @@ -135,10 +135,10 @@ class Grids(problem.Problem): if shape == "fwd_3_bck_123": forward_mask = ((T % (S + 1) != 0) & (T >= 3 * (S + 1))).long() - backward_mask = ((T % (S + 1) != 0) & (T >= S + 1)).long() + backward_mask = ((T % (S + 1) != 0) & (T >= 1 * (S + 1))).long() elif shape == "fwd_012_bck_0": forward_mask = ((T % (S + 1) != 0) & (T < 3 * (S + 1))).long() - backward_mask = ((T % (S + 1) != 0) & (T < S + 1)).long() + backward_mask = ((T % (S + 1) != 0) & (T < 1 * (S + 1))).long() elif shape == "fwd_3_bck_3": forward_mask = ((T % (S + 1) != 0) & (T >= 3 * (S + 1))).long() backward_mask = ((T % (S + 1) != 0) & (T >= 3 * (S + 1))).long() @@ -1277,7 +1277,6 @@ class Grids(problem.Problem): S = self.height * self.width Bs = prompts[:, 2 * (S + 1) + 1 : 2 * (S + 1) + S + 1] f_Bs = answers[:, 1:] - print(f"{prompts.size()=} {answers.size()=} {Bs.size()=} {f_Bs.size()=}") return (Bs == f_Bs).long().min(dim=-1).values > 0 def generate_prompts_and_answers_(self, nb, tasks=None, progress_bar=False): @@ -1371,8 +1370,8 @@ if __name__ == "__main__": nb, nrow = 8, 2 # nb, nrow = 8, 2 - for t in grids.all_tasks: - # for t in [grids.task_compute]: + # for t in grids.all_tasks: + for t in [grids.task_convex]: print(t.__name__) prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t]) # prompts[...] = torch.randint(grids.nb_token_values(), prompts.size()) diff --git a/main.py b/main.py index 562a95d..c9c30c3 100755 --- a/main.py +++ b/main.py @@ -558,6 +558,8 @@ for k in range(args.nb_gpts): ###################################################################### +current_epoch = 0 + if args.resume: try: for model in models: @@ -580,6 +582,15 @@ if args.resume: log_string(f"cannot find {filename}") pass + try: + filename = "state.pth" + state = torch.load(os.path.join(args.result_dir, filename)) + log_string(f"successfully loaded {filename}") + current_epoch = state["current_epoch"] + except FileNotFoundError: + log_string(f"cannot find {filename}") + pass + except: log_string(f"error when loading {filename}.") exit(1) @@ -616,7 +627,7 @@ if args.dirty_debug: ###################################################################### -for n_epoch in range(args.nb_epochs): +for n_epoch in range(current_epoch, args.nb_epochs): log_string(f"--- epoch {n_epoch} ----------------------------------------") cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models]) @@ -675,6 +686,11 @@ for n_epoch in range(args.nb_epochs): ) log_string(f"wrote {filename}") + state = {"current_epoch": n_epoch} + filename = "state.pth" + torch.save(state, os.path.join(args.result_dir, filename)) + log_string(f"wrote {filename}") + # Renew the training samples for model in weakest_models: diff --git a/quiz_machine.py b/quiz_machine.py index 91eb3ac..c006ea4 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -447,9 +447,7 @@ class QuizMachine: ############################################################### - def solution_nb_correct( - self, models_for_validation, c_quizzes, bidirectional_validation=True - ): + def solution_nb_correct(self, models_for_validation, c_quizzes): seq_logproba = torch.zeros( c_quizzes.size(0), max([m.id for m in models_for_validation]) + 1, @@ -457,6 +455,12 @@ class QuizMachine: ) nb_correct = 0 + correct_models = torch.empty( + c_quizzes.size(0), + max([m.id for m in models_for_validation]) + 1, + device=self.device, + dtype=torch.int64, + ) seq_logproba[...] = 0.0 @@ -478,7 +482,9 @@ class QuizMachine: device=self.device, ) - correct = (c_quizzes == result).long().min(dim=-1).values + correct_models[:, model.id] = ( + (c_quizzes == result).long().min(dim=-1).values + ) # ------------------------------- @@ -501,13 +507,17 @@ class QuizMachine: device=self.device, ) - flipped_correct = (c_quizzes == result).long().min(dim=-1).values + correct_models[:, model.id] *= ( + (c_quizzes == result).long().min(dim=-1).values + ) # ------------------------------- - nb_correct += correct * flipped_correct + i = correct_models.sum(dim=1) == correct_models.size(1) - 1 + c = (correct_models[i] == 0).long().sum(dim=0) + self.logger(f"nb_failures_on_validated {tuple(x.item() for x in c)}") - return nb_correct.to("cpu") + return correct_models.sum(dim=1).to("cpu") ###############################################################