From c427907107a4941caa15c44516999ed2c507fa0c Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 24 Jul 2024 22:51:17 +0200 Subject: [PATCH] Update. --- grids.py | 74 +++++++++++++++++++++++-------------------------- main.py | 10 +++++-- quiz_machine.py | 34 +++++++++++++++++------ 3 files changed, 66 insertions(+), 52 deletions(-) diff --git a/grids.py b/grids.py index ee3a1e6..25bbc80 100755 --- a/grids.py +++ b/grids.py @@ -138,25 +138,30 @@ class Grids(problem.Problem): return struct def reconfigure(self, quizzes, struct=("A", "f_A", "B", "f_B")): + if torch.is_tensor(quizzes): + return self.reconfigure([quizzes])[0] + S = self.height * self.width - result = quizzes.new(quizzes.size()) + result = [x.new(x.size()) for x in quizzes] - struct_from = self.get_structure(quizzes[:1]) - i = self.indices_select(quizzes, struct_from) + struct_from = self.get_structure(quizzes[0][:1]) + i = self.indices_select(quizzes[0], struct_from) sf = dict((l, n) for n, l in enumerate(struct_from)) - q = quizzes.reshape(-1, 4, S + 1)[i] - - result[i, 0 * (S + 1) : 1 * (S + 1)] = q[:, sf[struct[0]], :] - result[i, 1 * (S + 1) : 2 * (S + 1)] = q[:, sf[struct[1]], :] - result[i, 2 * (S + 1) : 3 * (S + 1)] = q[:, sf[struct[2]], :] - result[i, 3 * (S + 1) : 4 * (S + 1)] = q[:, sf[struct[3]], :] + for q in range(4): + k = sf[struct[q]] + for x, y in zip(quizzes, result): + l = x.size(1) // 4 + y[i, q * l : (q + 1) * l] = x[i, k * l : (k + 1) * l] j = i == False if j.any(): - result[j] = self.reconfigure(quizzes[j], struct=struct) + for z, y in zip( + self.reconfigure([x[j] for x in quizzes], struct=struct), result + ): + y[j] = z return result @@ -303,6 +308,7 @@ class Grids(problem.Problem): margin=8, ): quizzes = quizzes.to("cpu") + self.check_structure(quizzes, ("A", "f_A", "B", "f_B")) S = self.height * self.width @@ -339,8 +345,9 @@ class Grids(problem.Problem): colors = ( predicted_parts[:, :, None] * ( - correct_parts[:, :, None] * green[None, None, :] - + (1 - correct_parts[:, :, None]) * red[None, None, :] + (correct_parts[:, :, None] == 1).long() * green[None, None, :] + + (correct_parts[:, :, None] == 0).long() * gray[None, None, :] + + (correct_parts[:, :, None] == -1).long() * red[None, None, :] ) + (1 - predicted_parts[:, :, None]) * white[None, None, :] ) @@ -1321,21 +1328,18 @@ class Grids(problem.Problem): X[i1:i2, j1:j2] = c[n] f_X[i1:i2, j1:j2] = c[n] - while True: - i1, i2 = torch.randint(self.height, (2,)) - j1, j2 = torch.randint(self.width, (2,)) - if ( - abs(i1 - i2) + abs(j1 - j2) > 2 - and X[i1, j1] == 0 - and X[i2, j2] == 0 - ): - break - - d2 = self.compdist(X, i2, j2) - d = self.compdist(X, i1, j1) + i1, i2 = torch.randint(self.height, (2,)) + j1, j2 = torch.randint(self.width, (2,)) + if ( + abs(i1 - i2) + abs(j1 - j2) > 2 + and X[i1, j1] == 0 + and X[i2, j2] == 0 + ): + d2 = self.compdist(X, i2, j2) + d = self.compdist(X, i1, j1) - if d2[i1, j1] < 2 * self.width: - break + if d2[i1, j1] < 2 * self.width: + break m = ((d + d2) == d[i2, j2]).long() f_X[...] = m * c[-1] + (1 - m) * f_X @@ -1491,32 +1495,22 @@ if __name__ == "__main__": nb, nrow = 128, 4 # nb, nrow = 8, 2 - for t in grids.all_tasks: - # for t in [grids.task_replace_color]: - # for t in [grids.task_symbols]: + # for t in grids.all_tasks: + for t in [grids.task_path]: print(t.__name__) quizzes = grids.generate_w_quizzes_(nb, tasks=[t]) - # predicted_parts = quizzes.new_zeros(quizzes.size(0), 4) - # predicted_parts[:, 3] = torch.randint( - # 2, (quizzes.size(0),), device=quizzes.device - # ) - # predicted_parts[:, :3] = 1 - predicted_parts[:, 3:] - # correct_parts = torch.randint(2, (quizzes.size(0), 4), device=quizzes.device) - # correct_parts[:, 1:2] = correct_parts[:, :1] grids.save_quizzes_as_image( "/tmp", t.__name__ + ".png", quizzes, - # predicted_parts=predicted_parts, - # correct_parts=correct_parts, ) # exit(0) nb = 1000 - for t in grids.all_tasks: - # for t in [grids.task_compute]: + # for t in grids.all_tasks: + for t in [grids.task_path]: start_time = time.perf_counter() w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t]) delay = time.perf_counter() - start_time diff --git a/main.py b/main.py index deba848..fa33b4e 100755 --- a/main.py +++ b/main.py @@ -89,7 +89,7 @@ parser.add_argument("--nb_gpts", type=int, default=5) parser.add_argument("--max_fail_to_validate", type=int, default=1) -parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975) +parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.9) parser.add_argument("--proba_understands", type=float, default=0.9) @@ -99,7 +99,7 @@ parser.add_argument("--temperature_hot", type=float, default=1.5) parser.add_argument("--temperature_cold", type=float, default=0.75) -parser.add_argument("--nb_rounds", type=int, default=3) +parser.add_argument("--nb_rounds", type=int, default=1) parser.add_argument("--c_quiz_validation_mode", type=str, default="predict") @@ -549,7 +549,8 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 v = " ".join([str(n.item()) for n in r]) f.write(f"{n}: {v}\n") - quiz_machine.save_quizzes_as_image(args.result_dir, prefix, vq) + vq = quiz_machine.problem.reconfigure(vq, ("A", "f_A", "B", "f_B")) + quiz_machine.problem.save_quizzes_as_image(args.result_dir, prefix, vq) ###################################################################### @@ -724,6 +725,9 @@ for n_epoch in range(current_epoch, args.nb_epochs): temperature_cold=args.temperature_cold, ) + c_quizzes = quiz_machine.problem.reconfigure( + c_quizzes, ("A", "f_A", "B", "f_B") + ) quiz_machine.problem.save_quizzes_as_image( args.result_dir, f"non_validated_{n_epoch:04d}_{model.id:02d}.png", diff --git a/quiz_machine.py b/quiz_machine.py index 749ae8b..8f14fa0 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -200,7 +200,7 @@ class QuizMachine: device=self.device, ) - correct = (result == quizzes).min(dim=1).values + correct = (result == quizzes).min(dim=1).values.long() return result, correct @@ -213,7 +213,7 @@ class QuizMachine: ): input = input.to(self.device) result = input.new(input.size()) - correct = torch.empty(input.size(0), device=input.device, dtype=torch.bool) + correct = input.new(input.size(0)) predicted_parts = input.new(input.size(0), 4) nb = 0 for struct, mask in [ @@ -226,19 +226,40 @@ class QuizMachine: result[i], correct[i] = self.predict( model=model, quizzes=input[i], struct=struct, mask=mask ) + predicted_parts[i] = torch.tensor(mask, device=self.device)[None, :] + correct[i] = (2 * correct[i] - 1) * ( + predicted_parts[i].sum(dim=-1) == 1 + ).long() assert nb == input.size(0) - main_test_accuracy = correct.sum() / correct.size(0) + nb_correct = (correct == 1).long().sum() + nb_total = (correct != 0).long().sum() + self.logger( + f"test_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total}" + ) + + main_test_accuracy = nb_correct / nb_total ############################## + correct_parts = predicted_parts * correct[:, None] + + result = result[:128] + predicted_parts = predicted_parts[:128] + correct_parts = correct_parts[:128] + + self.problem.reconfigure( + [result, predicted_parts, correct_parts], ("A", "f_A", "B", "f_B") + ) + self.problem.save_quizzes_as_image( result_dir, f"culture_prediction_{n_epoch:04d}_{model.id:02d}.png", - quizzes=result[:128], + quizzes=result, predicted_parts=predicted_parts, + correct_parts=correct_parts, ) return main_test_accuracy @@ -437,11 +458,6 @@ class QuizMachine: lt_noisy = lambda s, logits: logits / temperature_hot lt_clean = lambda s, logits: logits / temperature_cold - # lt_noisy = lambda s, logits: logits / ( - # 1 + 4 * (torch.rand(logits.size(), device=logits.device) < 1e-2).long() - # ) - # lt_clean = None - masked_inplace_autoregression( model=model_for_generation, batch_size=self.batch_size, -- 2.39.5