parser.add_argument("--nb_gpts", type=int, default=5)
-parser.add_argument("--max_fail_to_validate", type=int, default=1)
+parser.add_argument("--max_fail_to_validate", type=int, default=2)
parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.98)
-parser.add_argument("--proba_understands", type=float, default=0.99)
+parser.add_argument("--proba_understands", type=float, default=0.9)
parser.add_argument("--proba_not_understands", type=float, default=0.5)
def model_transformer_hot(model):
# model.temperature = args.temperature_hot
- model.set_noise_injection(1.0, ("ffw", 2))
+ model.set_noise_injection(0.5, ("ffw", args.nb_blocks // 2))
def model_transformer_cold(model):
# We use the model that has generated the fewest quizzes to
# balance the number of quizzes per model overall
- model_for_generation = sorted(
- models, key=lambda m: nb_validated_per_model[m.id]
- )[0]
+ # model_for_generation = sorted(
+ # models, key=lambda m: nb_validated_per_model[m.id]
+ # )[0]
+
+ model_for_generation = models[torch.randint(len(models), (1,)).item()]
# We generate quizzes with a procedure that injects some
# structured noise
######################################################################
+#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+if False:
+ for model in models:
+ for b in range(args.nb_blocks):
+ for o in [0.5, 1.0, 2.0]:
+
+ def model_transformer_hot(model):
+ # model.temperature = args.temperature_hot
+ model.set_noise_injection(o, ("ffw", b))
+
+ def model_transformer_cold(model):
+ pass
+
+ # model.temperature = args.temperature_cold
+
+ c_quizzes_procedure = [
+ # (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot),
+ # (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold),
+ # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
+ (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), model_transformer_hot),
+ ]
+
+ c_quizzes = quiz_machine.generate_c_quizzes(
+ 128, model_for_generation=model, procedure=c_quizzes_procedure
+ )
+
+ filename = f"generated_{b:02d}_{o:.02f}_{model.id:02d}.png"
+ print(filename)
+
+ quiz_machine.problem.save_quizzes_as_image(
+ args.result_dir,
+ filename,
+ c_quizzes,
+ )
+ exit(0)
+
+#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
if args.dirty_debug:
args.accuracy_to_make_c_quizzes = 0.0
args.nb_gpts = 2