From 316b1147f2ad0755ba4a5acf50f92afc319135c3 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 15 Jul 2024 15:28:11 +0200 Subject: [PATCH] Update. --- grids.py | 28 ++++++++++++++++------------ main.py | 36 +++++++++++++++++++----------------- quiz_machine.py | 4 ++-- 3 files changed, 37 insertions(+), 31 deletions(-) diff --git a/grids.py b/grids.py index eea8c6c..7752136 100755 --- a/grids.py +++ b/grids.py @@ -143,7 +143,7 @@ class Grids(problem.Problem): self.task_scale, self.task_symbols, self.task_isometry, - # self.task_islands, + self.task_islands, ] if tasks is None: @@ -617,8 +617,8 @@ class Grids(problem.Problem): while True: error = False - N = torch.randint(5, (1,)).item() + 1 - c = torch.zeros(N + 1) + N = torch.randint(5, (1,)).item() + 2 + c = torch.zeros(N + 1, dtype=torch.int64) c[1:] = torch.randperm(len(self.colors) - 1)[:N] + 1 for X, f_X in [(A, f_A), (B, f_B)]: @@ -635,18 +635,20 @@ class Grids(problem.Problem): X[...] = self.cache_count.pop() - k = (X.max() + 1 + (c.size(0) - 1)).item() - V = torch.arange(k) // (c.size(0) - 1) - V = (V + torch.rand(V.size())).sort().indices[: X.max() + 1] % ( - c.size(0) - 1 - ) + 1 + # k = (X.max() + 1 + (c.size(0) - 1)).item() + # V = torch.arange(k) // (c.size(0) - 1) + # V = (V + torch.rand(V.size())).sort().indices[: X.max() + 1] % ( + # c.size(0) - 1 + # ) + 1 + V = torch.randint(c.size(0) - 1, (X.max() + 1,)) + 1 V[0] = 0 + NB = F.one_hot(c[V]).sum(dim=0) X[...] = c[V[X]] if F.one_hot(X.flatten()).max(dim=0).values.sum().item() == N + 1: f_X[...] = 0 for e in range(1, N + 1): - for j in range((X == c[e]).sum() + 1): + for j in range(NB[c[e]]): if j < self.width: f_X[e - 1, j] = c[e] else: @@ -659,6 +661,8 @@ class Grids(problem.Problem): if not error: break + assert F.one_hot(A.flatten()).max(dim=0).values.sum() >= 3 + # @torch.compile def task_trajectory(self, A, f_A, B, f_B): c = torch.randperm(len(self.colors) - 1)[:2] + 1 @@ -1155,19 +1159,19 @@ if __name__ == "__main__": # nb, nrow = 8, 2 # for t in grids.all_tasks: - for t in [grids.task_distance]: + for t in [grids.task_count]: print(t.__name__) prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t]) grids.save_quiz_illustrations( "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow ) - # exit(0) + exit(0) nb = 1000 # for t in grids.all_tasks: - for t in [grids.task_distance]: + for t in [grids.task_count]: start_time = time.perf_counter() prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t]) delay = time.perf_counter() - start_time diff --git a/main.py b/main.py index cdaacdf..07fec96 100755 --- a/main.py +++ b/main.py @@ -441,32 +441,34 @@ def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100): duration = time.perf_counter() - start_time - if nb_validated > 0: - e = (nb_to_create - nb_validated) * duration / nb_validated - if e > 0: - e = "~" + str(datetime.timedelta(seconds=int(e))) - else: - e = "0s" + if nb_validated > 0 and nb_validated < nb_to_create: + d = (nb_to_create - nb_validated) * duration / nb_validated else: - e = "???" + d = 0 + + e = (datetime.datetime.now() + datetime.timedelta(seconds=d)).strftime( + "%a %H:%M" + ) log_string( - f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create} (remaining time {e})" + f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create} (finish {e})" ) # store the new c_quizzes which have been validated - quiz_machine.reverse_random_half_in_place(validated_quizzes) - quiz_machine.store_c_quizzes(validated_quizzes[:nb_for_train], for_train=True) - quiz_machine.store_c_quizzes( - validated_quizzes[nb_for_train:nb_to_create], for_train=False - ) + v_train = validated_quizzes[:nb_for_train] + quiz_machine.store_c_quizzes(v_train, for_train=True) + quiz_machine.store_c_quizzes(quiz_machine.reverse_time(v_train), for_train=True) + + v_test = validated_quizzes[nb_for_train:nb_to_create] + quiz_machine.store_c_quizzes(v_test, for_train=False) + quiz_machine.store_c_quizzes(quiz_machine.reverse_time(v_test), for_train=False) ###################################################################### # save images with their logprobas - vq = validated_quizzes[:72] - vl = validated_logprobas[:72] + vq = validated_quizzes[:128] + vl = validated_logprobas[:128] if vq.size(0) > 0: prefix = f"culture_c_quiz_{n_epoch:04d}" @@ -591,10 +593,10 @@ if args.max_percents_of_test_in_train >= 0: ###################################################################### if args.nb_new_c_quizzes_for_train is None: - args.nb_new_c_quizzes_for_train = args.nb_train_samples // 50 + args.nb_new_c_quizzes_for_train = args.nb_train_samples // 100 if args.nb_new_c_quizzes_for_test is None: - args.nb_new_c_quizzes_for_test = args.nb_test_samples // 50 + args.nb_new_c_quizzes_for_test = args.nb_test_samples // 100 log_string( f"nb_new_c_quizzes_for_train {args.nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {args.nb_new_c_quizzes_for_test}" diff --git a/quiz_machine.py b/quiz_machine.py index 927a349..bcb89ec 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -415,8 +415,8 @@ class QuizMachine: self.save_quiz_illustrations( result_dir, f"culture_prediction_{n_epoch:04d}_{model.id:02d}", - quizzes=test_result[:72], - mistakes=test_correct[:72] * 2 - 1, + quizzes=test_result[:128], + mistakes=test_correct[:128] * 2 - 1, ) return main_test_accuracy -- 2.20.1