######################################################################
-def save_additional_results(model, models, c_quizzes_procedure):
+def save_additional_results(n_epoch, model, models, c_quizzes_procedure):
# Save generated quizzes with the successive generation steps
recorder = []
proba_other_solutions = model_proba_solutions(
model, solved_c_quizzes[s]
)
+
+ # proba_other_solutions += torch.rand(proba_other_solutions.size()) * 1e-6
+
proba_other_solutions[dont_get_this_quiz] = -1
# print(
# f"\nDEBUG {proba_own_solution[s,model.id]=} {proba_other_solutions=}\n"
args.nb_new_c_quizzes_for_train = 100
args.nb_new_c_quizzes_for_test = 10
-if args.test == "gen":
- save_additional_results(model, models, c_quizzes_procedure)
+######################################################################
+######################################################################
+
+
+class Recorder(nn.Module):
+ def __init__(self, tape):
+ super().__init__()
+ self.tape = tape
+
+ def forward(self, input):
+ self.tape.append(input)
+ return input
+
+
+if args.test == "mlp":
+ model = models[0]
+ tape_input, tape_output = [], []
+ L = len(model.trunk)
+ model.trunk.insert(L // 2 + 1, Recorder(tape_output))
+ model.trunk.insert(L // 2, Recorder(tape_input))
+
+ print(model.trunk)
+ train_input = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples)
+
+ with torch.autograd.no_grad():
+ model.to(main_device).eval()
+ for input in train_input.split(args.batch_size):
+ input = input.to(main_device)
+ output = model(mygpt.BracketedSequence(input)).x
+
+ train_input = torch.cat([bs.x for bs in tape_input], dim=0)
+ train_targets = torch.cat([bs.x for bs in tape_output], dim=0)
+
+ print(f"{train_input.size()=} {train_targets.size()=}")
+
+ exit(0)
+
+######################################################################
+######################################################################
+
+if args.test == "reject":
+ record = []
+
+ c_quizzes_procedure = [
+ (("f_B", "f_A", "A", "B"), (1, 1, 1, 1), model_modifier_hot),
+ (("f_B", "B", "f_A", "A"), (0, 0, 1, 1), model_modifier_cold),
+ (("f_B", "f_A", "A", "B"), (0, 0, 0, 1), model_modifier_cold),
+ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold),
+ (("f_B", "B", "f_A", "A"), (0, 0, 1, 1), model_modifier_cold),
+ (("f_B", "f_A", "A", "B"), (0, 0, 0, 1), model_modifier_cold),
+ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold),
+ (("f_B", "B", "f_A", "A"), (0, 0, 1, 1), model_modifier_cold),
+ (("f_B", "f_A", "A", "B"), (0, 0, 0, 1), model_modifier_cold),
+ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold),
+ ]
+
+ while sum([x.size(0) for x in record]) < 64:
+ model = models[torch.randint(len(models), (1,)).item()]
+ c_quizzes = quiz_machine.generate_c_quizzes(
+ 64,
+ model_for_generation=model,
+ procedure=c_quizzes_procedure,
+ )
+
+ p = quiz_machine.models_logprobas(
+ model,
+ c_quizzes,
+ ("A", "f_A", "B", "f_B"),
+ (1, 1, 1, 1),
+ temperature=1,
+ ).exp()
+
+ p_hot = quiz_machine.models_logprobas(
+ model,
+ c_quizzes,
+ ("A", "f_A", "B", "f_B"),
+ (1, 1, 1, 1),
+ temperature=args.temperature_hot,
+ ).exp()
+
+ to_keep = p_hot * torch.rand(p_hot.size(), device=p_hot.device) >= p
+ record.append(c_quizzes[to_keep])
+
+ print("NB_KEPT", sum([x.size(0) for x in record]))
+
+ filename = f"sampling_with_rejection.png"
+
+ quiz_machine.problem.save_quizzes_as_image(
+ args.result_dir,
+ filename,
+ quizzes=c_quizzes,
+ )
+
+ log_string(f"wrote {filename}")
+
exit(0)
######################################################################
log_string(f"wrote {filename}")
for model in weakest_models:
- save_additional_results(model, models, c_quizzes_procedure)
+ save_additional_results(n_epoch, model, models, c_quizzes_procedure)
######################################################################