parser.add_argument("--min_succeed_to_validate", type=int, default=2)
-parser.add_argument("--max_fail_to_validate", type=int, default=3)
-
parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95)
-parser.add_argument("--proba_understands", type=float, default=0.95)
-
-parser.add_argument("--proba_not_understands", type=float, default=0.1)
-
-parser.add_argument("--temperature_hot", type=float, default=1.5)
-
-parser.add_argument("--temperature_cold", type=float, default=1)
-
parser.add_argument("--prompt_noise", type=float, default=0.05)
+parser.add_argument("--nb_hints", type=int, default=5)
+
parser.add_argument("--dirty_debug", action="store_true", default=False)
parser.add_argument("--test", type=str, default=None)
x_t = (1 - mask_generate) * x_0 + mask_generate * x_t
- #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
- # filename = f"debug.png"
-
- # quiz_machine.problem.save_quizzes_as_image(
- # args.result_dir,
- # filename,
- # quizzes=x_t,
- # )
-
- # log_string(f"wrote {filename}")
- # exit(0)
- #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-
# We may inject noise to prevent high-complexity unstructured
# signal from being generated as a way of "increasing reasoning
# complexity"
######################################################################
-def ae_generate(model, x_0, mask_generate, nb_iterations_max=50):
+def ae_generate(model, x_0, mask_generate, nb_iterations_max=50, mask_hints=None):
noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
- x_t = (1 - mask_generate) * x_0 + mask_generate * noise
+ if mask_hints is None:
+ x_t = (1 - mask_generate) * x_0 + mask_generate * noise
+ else:
+ mask = mask_generate * (1 - mask_hints)
+ x_t = (1 - mask) * x_0 + mask * noise
one_iteration_prediction = deterministic(mask_generate)[:, None]
# Save some images
- if n_epoch < 50:
+ if n_epoch < 100:
for f, record in [("prediction", record_d), ("generation", record_nd)]:
result, predicted_parts, correct_parts = bag_to_tensors(record)
######################################################################
-def quiz_validation(models, c_quizzes, local_device):
- nb_have_to_be_correct = 3
- nb_have_to_be_wrong = 1
- nb_mistakes_to_be_wrong = 5
-
+def quiz_validation(
+ models,
+ c_quizzes,
+ local_device,
+ nb_have_to_be_correct=3,
+ nb_have_to_be_not_correct=0,
+ nb_have_to_be_wrong=1,
+ nb_mistakes_to_be_wrong=5,
+ nb_hints=0,
+ nb_runs=1,
+):
record_wrong = []
nb_correct, nb_wrong = 0, 0
for i, model in enumerate(models):
assert i == model.id # a bit of paranoia
model = copy.deepcopy(model).to(local_device).eval()
-
correct, wrong = True, False
-
for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
mask_generate = quiz_machine.make_quiz_mask(
- quizzes=c_quizzes, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
+ quizzes=c_quizzes,
+ quad_order=("A", "f_A", "B", "f_B"),
+ quad_mask=quad,
)
- result = ae_generate(model, (1 - mask_generate) * c_quizzes, mask_generate)
- nb_mistakes = (result != c_quizzes).long().sum(dim=1)
- correct = correct & (nb_mistakes == 0)
- wrong = wrong | (nb_mistakes >= nb_mistakes_to_be_wrong)
+ for _ in range(nb_runs):
+ if nb_hints == 0:
+ mask_hints = None
+ else:
+ u = (
+ torch.rand(mask_generate.size(), device=mask_generate.device)
+ * mask_generate
+ )
+ mask_hints = (
+ u > u.sort(dim=1, descending=True).values[:, nb_hints, None]
+ ).long()
+
+ result = ae_generate(
+ model=model,
+ x_0=(1 - mask_generate) * c_quizzes,
+ mask_generate=mask_generate,
+ mask_hints=mask_hints,
+ )
+
+ nb_mistakes = (result != c_quizzes).long().sum(dim=1)
+ correct = correct & (nb_mistakes == 0)
+ wrong = wrong | (nb_mistakes >= nb_mistakes_to_be_wrong)
record_wrong.append(wrong[:, None])
nb_correct += correct.long()
c_quizzes = ae_generate(model, template, mask_generate)
- #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
- ## for quad in [(0, 1, 0, 0), (0, 0, 0, 1)]:
- ## mask_generate = quiz_machine.make_quiz_mask(
- ## quizzes=c_quizzes,
- ## quad_order=("A", "f_A", "B", "f_B"),
- ## quad_mask=quad,
- ## )
- ## c_quizzes = ae_generate(
- ## model,
- ## (1 - mask_generate) * c_quizzes,
- ## mask_generate,
- ## )
- #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-
to_keep = quiz_machine.problem.trivial(c_quizzes) == False
c_quizzes = c_quizzes[to_keep]
if c_quizzes.size(0) > 0:
- to_keep, record_wrong = quiz_validation(models, c_quizzes, local_device)
+ to_keep, record_wrong = quiz_validation(
+ models, c_quizzes, local_device, nb_hints=args.nb_hints
+ )
q = c_quizzes[to_keep]
if q.size(0) > 0:
######################################################################
-def save_c_quizzes_with_scores(models, c_quizzes, filename):
+def save_c_quizzes_with_scores(models, c_quizzes, nb, filename, solvable_only=False):
l = []
+ if solvable_only:
+ to_keep, _ = quiz_validation(
+ models,
+ c_quizzes,
+ main_device,
+ nb_have_to_be_correct=1,
+ nb_have_to_be_wrong=0,
+ nb_hints=0,
+ )
+ c_quizzes = c_quizzes[to_keep]
+
+ c_quizzes = c_quizzes[
+ torch.randperm(c_quizzes.size(0), device=c_quizzes.device)[:nb]
+ ]
+
with torch.autograd.no_grad():
for model in models:
model = copy.deepcopy(model).to(main_device).eval()
# --------------------------------------------------------------------
filename = f"culture_c_quiz_{n_epoch:04d}.png"
- save_c_quizzes_with_scores(models, c_quizzes[:128], filename)
+ save_c_quizzes_with_scores(
+ models, c_quizzes, 256, filename, solvable_only=False
+ )
+ filename = f"culture_c_quiz_{n_epoch:04d}_solvable.png"
+ save_c_quizzes_with_scores(models, c_quizzes, 256, filename, solvable_only=True)
log_string(f"generated_c_quizzes {c_quizzes.size()=}")