From ffbc1fbdaef5351eb89ccd10562694491778f1e0 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 15 Aug 2024 12:59:47 +0200 Subject: [PATCH] Update. --- main.py | 104 ++++++++++++++++++++++++++++++++++++++++++++++-- quiz_machine.py | 3 +- 2 files changed, 102 insertions(+), 5 deletions(-) diff --git a/main.py b/main.py index 4326491..4375985 100755 --- a/main.py +++ b/main.py @@ -452,7 +452,7 @@ c_quizzes_procedure = [ ###################################################################### -def save_additional_results(model, models, c_quizzes_procedure): +def save_additional_results(n_epoch, model, models, c_quizzes_procedure): # Save generated quizzes with the successive generation steps recorder = [] @@ -592,6 +592,9 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test): proba_other_solutions = model_proba_solutions( model, solved_c_quizzes[s] ) + + # proba_other_solutions += torch.rand(proba_other_solutions.size()) * 1e-6 + proba_other_solutions[dont_get_this_quiz] = -1 # print( # f"\nDEBUG {proba_own_solution[s,model.id]=} {proba_other_solutions=}\n" @@ -945,8 +948,101 @@ if args.dirty_debug: args.nb_new_c_quizzes_for_train = 100 args.nb_new_c_quizzes_for_test = 10 -if args.test == "gen": - save_additional_results(model, models, c_quizzes_procedure) +###################################################################### +###################################################################### + + +class Recorder(nn.Module): + def __init__(self, tape): + super().__init__() + self.tape = tape + + def forward(self, input): + self.tape.append(input) + return input + + +if args.test == "mlp": + model = models[0] + tape_input, tape_output = [], [] + L = len(model.trunk) + model.trunk.insert(L // 2 + 1, Recorder(tape_output)) + model.trunk.insert(L // 2, Recorder(tape_input)) + + print(model.trunk) + train_input = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples) + + with torch.autograd.no_grad(): + model.to(main_device).eval() + for input in train_input.split(args.batch_size): + input = input.to(main_device) + output = model(mygpt.BracketedSequence(input)).x + + train_input = torch.cat([bs.x for bs in tape_input], dim=0) + train_targets = torch.cat([bs.x for bs in tape_output], dim=0) + + print(f"{train_input.size()=} {train_targets.size()=}") + + exit(0) + +###################################################################### +###################################################################### + +if args.test == "reject": + record = [] + + c_quizzes_procedure = [ + (("f_B", "f_A", "A", "B"), (1, 1, 1, 1), model_modifier_hot), + (("f_B", "B", "f_A", "A"), (0, 0, 1, 1), model_modifier_cold), + (("f_B", "f_A", "A", "B"), (0, 0, 0, 1), model_modifier_cold), + (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold), + (("f_B", "B", "f_A", "A"), (0, 0, 1, 1), model_modifier_cold), + (("f_B", "f_A", "A", "B"), (0, 0, 0, 1), model_modifier_cold), + (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold), + (("f_B", "B", "f_A", "A"), (0, 0, 1, 1), model_modifier_cold), + (("f_B", "f_A", "A", "B"), (0, 0, 0, 1), model_modifier_cold), + (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold), + ] + + while sum([x.size(0) for x in record]) < 64: + model = models[torch.randint(len(models), (1,)).item()] + c_quizzes = quiz_machine.generate_c_quizzes( + 64, + model_for_generation=model, + procedure=c_quizzes_procedure, + ) + + p = quiz_machine.models_logprobas( + model, + c_quizzes, + ("A", "f_A", "B", "f_B"), + (1, 1, 1, 1), + temperature=1, + ).exp() + + p_hot = quiz_machine.models_logprobas( + model, + c_quizzes, + ("A", "f_A", "B", "f_B"), + (1, 1, 1, 1), + temperature=args.temperature_hot, + ).exp() + + to_keep = p_hot * torch.rand(p_hot.size(), device=p_hot.device) >= p + record.append(c_quizzes[to_keep]) + + print("NB_KEPT", sum([x.size(0) for x in record])) + + filename = f"sampling_with_rejection.png" + + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, + filename, + quizzes=c_quizzes, + ) + + log_string(f"wrote {filename}") + exit(0) ###################################################################### @@ -1018,7 +1114,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): log_string(f"wrote {filename}") for model in weakest_models: - save_additional_results(model, models, c_quizzes_procedure) + save_additional_results(n_epoch, model, models, c_quizzes_procedure) ###################################################################### diff --git a/quiz_machine.py b/quiz_machine.py index 1fe2e94..0bdaaec 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -294,6 +294,7 @@ class QuizMachine: struct, mask_loss, mask_noise=None, + temperature=1.0, device=None, ): if device is None: @@ -323,7 +324,7 @@ class QuizMachine: quiz_mask_loss = self.make_quiz_mask( input, struct=struct, mask=mask_loss ) - output = model(mygpt.BracketedSequence(input)).x + output = model(mygpt.BracketedSequence(input)).x / temperature l[...] = ( -F.cross_entropy(output.transpose(1, 2), input, reduction="none") * quiz_mask_loss -- 2.39.5