From 6b3d2f07d7d29957a9a7c4d379f167c61500d7dd Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 13 Aug 2024 15:53:01 +0200 Subject: [PATCH] Update. --- main.py | 142 +++++++++++++++----------------------------------------- 1 file changed, 37 insertions(+), 105 deletions(-) diff --git a/main.py b/main.py index dda62af..5ef583c 100755 --- a/main.py +++ b/main.py @@ -126,13 +126,6 @@ parser.add_argument( help="A comma-separated subset of: " + grids_tasks + ".", ) -parser.add_argument( - "--grids_science_tasks", - type=str, - default=None, - help="A comma-separated subset of: " + grids_tasks + ", or None.", -) - ###################################################################### parser.add_argument("--sky_height", type=int, default=6) @@ -152,14 +145,6 @@ args = parser.parse_args() if args.result_dir is None: args.result_dir = f"results_culture" -assert not args.grids_science_tasks or ( - len( - set(args.grids_world_tasks.split(",")) - & set(args.grids_science_tasks.split(",")) - ) - == 0 -), "World and science tasks have to be disjoint" - ###################################################################### default_args = { @@ -302,43 +287,12 @@ else: assert args.nb_train_samples % args.batch_size == 0 assert args.nb_test_samples % args.batch_size == 0 -if args.problem == "sky": - problem = sky.Sky( - height=args.sky_height, - width=args.sky_width, - nb_birds=args.sky_nb_birds, - nb_iterations=args.sky_nb_iterations, - speed=args.sky_speed, - max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100, - chunk_size=100, - nb_threads=args.nb_threads, - ) - -elif args.problem == "grids": - problem = grids.Grids( - max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100, - chunk_size=100, - nb_threads=args.nb_threads, - tasks=args.grids_world_tasks, - ) - - if args.grids_science_tasks is None: - science_w_quizzes = None - else: - science_problem = grids.Grids( - max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100, - chunk_size=100, - nb_threads=args.nb_threads, - tasks=args.grids_science_tasks, - ) - science_w_quizzes = science_problem.generate_w_quizzes(100) - - if not args.resume: - science_problem.save_some_examples(args.result_dir, "science_") - - -else: - raise ValueError +problem = grids.Grids( + max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100, + chunk_size=100, + nb_threads=args.nb_threads, + tasks=args.grids_world_tasks, +) if not args.resume: problem.save_some_examples(args.result_dir) @@ -516,7 +470,7 @@ c_quizzes_procedure = [ ###################################################################### -def save_additional_results(model, models, science_w_quizzes): +def save_additional_results(model, models): # Save generated quizzes with the successive steps recorder = [] @@ -576,49 +530,6 @@ def save_additional_results(model, models, science_w_quizzes): log_string(f"wrote {filename}") - ###################################################################### - - if science_w_quizzes is not None: - struct = ("A", "f_A", "B", "f_B") - mask = (0, 0, 0, 1) - result, correct, _ = quiz_machine.predict( - model=model, - quizzes=science_w_quizzes.to(main_device), - struct=struct, - mask=mask, - ) - - predicted_parts = torch.tensor(mask, device=correct.device)[None, :].expand( - correct.size(0), -1 - ) - correct = (2 * correct - 1) * (predicted_parts.sum(dim=-1) == 1).long() - - nb_correct = (correct == 1).long().sum() - nb_total = (correct != 0).long().sum() - - log_string( - f"science_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total}" - ) - - i = correct == 1 - j = correct != 1 - - result = torch.cat([result[i], result[j]], dim=0) - correct = torch.cat([correct[i], correct[j]], dim=0) - correct_parts = predicted_parts * correct[:, None] - - result = result[:128] - predicted_parts = predicted_parts[:128] - correct_parts = correct_parts[:128] - - quiz_machine.problem.save_quizzes_as_image( - args.result_dir, - f"culture_science_{n_epoch:04d}_{model.id:02d}.png", - quizzes=result, - predicted_parts=predicted_parts, - correct_parts=correct_parts, - ) - ###################################################################### @@ -642,6 +553,8 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test): for model in models: model.recorded_c_quizzes = [] + teaching_count = torch.zeros(len(models), len(models), dtype=torch.int64) + while nb_validated < nb_to_validate: model_for_generation = models[torch.randint(len(models), (1,)).item()] @@ -696,6 +609,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test): proba_other_solutions[dont_get_it] = -1 i = proba_other_solutions.argmax() model.recorded_c_quizzes.append(solved_c_quizzes[s, i]) + teaching_count[i, model.id] += 1 nb_validated += 1 duration = time.perf_counter() - start_time @@ -715,6 +629,10 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test): f"keep c_quizzes model {model_for_generation.id} validated {nb_validated} / {nb_to_generate_per_iteration} ({100*nb_validated/nb_to_generate_per_iteration:.02f}%) nb_accumulated {nb_validated} / {nb_to_validate} (finishes {e} -- {int((nb_validated * 3600)/duration)}/h)" ) + for s in range(teaching_count.size(0)): + o = [x.item() for x in teaching_count[s]] + log_string(f"teacher model {s} to {o}") + for model in models: new_bag = torch.cat([q[None, :] for q in model.recorded_c_quizzes], dim=0) @@ -723,15 +641,29 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test): if n > 0: model.train_c_quiz_bags.append(new_bag[:n]) if n < new_bag.size(0): - model.test_c_quiz_bags.append(new_bag[:n]) + model.test_c_quiz_bags.append(new_bag[n:]) - vq = new_bag[:128] + c_quizzes = new_bag[:128] - seq_logprobas = quiz_machine.models_logprobas( - models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) - ) + quiz_machine.models_logprobas( - models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) - ) + l = [ + quiz_machine.models_logprobas( + model, + c_quizzes, + ("A", "f_A", "B", "f_B"), + (0, 0, 0, 1), + (0, 0, 1, 0), + ) + + quiz_machine.models_logprobas( + model, + c_quizzes, + ("f_A", "A", "f_B", "B"), + (0, 0, 0, 1), + (0, 0, 1, 0), + ) + for model in models + ] + + seq_logprobas = torch.cat([x[:, None] for x in l], dim=1) probas = seq_logprobas.exp() @@ -744,7 +676,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test): filename = f"culture_c_quiz_{n_epoch:04d}.png" quiz_machine.problem.save_quizzes_as_image( - args.result_dir, filename, vq, comments=comments + args.result_dir, filename, c_quizzes, comments=comments ) log_string( @@ -916,7 +848,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): log_string(f"wrote {filename}") for model in weakest_models: - save_additional_results(model, models, science_w_quizzes) + save_additional_results(model, models) ###################################################################### -- 2.39.5