######################################################################
-# This is the key routine that decides what generated quizzes to keep
+def keep_good_quizzes(models, quizzes):
+ quizzes = quizzes[quiz_machine.non_trivial(quizzes)]
+ token_logprobas = quiz_machine.solution_token_logprobas(models, quizzes)
-# token_logprobas are NxMxT where M is the number of models
-# def compute_valid_quizzes_(token_logprobas):
-# warnings.warn("validation with uniform constraints", RuntimeWarning)
-# l = token_logprobas.min(dim=-1).values.sort(dim=-1).values
-# return (l[:, 0] < math.log(0.1)) & (l[:, 1] > math.log(0.5))
-
-# token_logprobas are NxMxT where M is the number of models
-
-
-def compute_valid_quizzes(token_logprobas):
l = token_logprobas.sum(dim=-1).sort(dim=-1).values
- return (l[:, 0] < math.log(args.proba_not_understands)) & (
+
+ to_keep = (l[:, 0] < math.log(args.proba_not_understands)) & (
l[:, 1] > math.log(args.proba_understands)
)
+ if args.dirty_debug:
+ # warnings.warn("DEBUG", RuntimeWarning)
+ to_keep = torch.rand(to_keep.size(), device=to_keep.device) < 0.5
-def extract_valid_quizzes_and_logprobas(recorded):
- validated_quizzes, validated_logprobas = [], []
- for quizzes, token_logprobas in recorded:
- validated_indices = compute_valid_quizzes(token_logprobas)
- validated_quizzes.append(quizzes[validated_indices])
- validated_logprobas.append(token_logprobas[validated_indices])
-
- if len(validated_quizzes) > 0:
- return torch.cat(validated_quizzes, dim=0), torch.cat(
- validated_logprobas, dim=0
- )
- else:
- return None, None
+ return quizzes[to_keep]
######################################################################
-def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
+def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
nb_to_create = nb_for_train + nb_for_test
-
- recorded_quizzes_logprobas = []
-
+ nb_to_generate_per_iteration = nb_to_create
nb_validated = 0
- start_time = time.perf_counter()
+ recorded = []
- nb_to_generate_per_iteration = nb_to_create
+ start_time = time.perf_counter()
while nb_validated < nb_to_create:
model_for_generation = models[torch.randint(len(models), (1,))]
temperature=args.generation_temperature,
)
- c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
+ c_quizzes = keep_good_quizzes(models, c_quizzes)
- if c_quizzes.size(0) > 0:
- token_logproba = quiz_machine.solution_token_logprobas(models, c_quizzes)
- recorded_quizzes_logprobas.append((c_quizzes, token_logproba))
+ nb_validated += c_quizzes.size(0)
- (
- validated_quizzes,
- validated_logprobas,
- ) = extract_valid_quizzes_and_logprobas(recorded_quizzes_logprobas)
-
- if validated_quizzes is not None:
- nb_validated = validated_quizzes.size(0)
+ recorded.append(c_quizzes)
duration = time.perf_counter() - start_time
f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create} (finishes {e})"
)
+ validated_quizzes = torch.cat(recorded, dim=0)
+
+ ######################################################################
# store the new c_quizzes which have been validated
v_train = validated_quizzes[:nb_for_train]
quiz_machine.store_c_quizzes(quiz_machine.reverse_time(v_test), for_train=False)
######################################################################
- # save images with their logprobas
+ # save images
vq = validated_quizzes[:128]
- vl = validated_logprobas[:128]
if vq.size(0) > 0:
prefix = f"culture_c_quiz_{n_epoch:04d}"
- filename = os.path.join(args.result_dir, prefix + "_logp.pth")
- torch.save(vl, filename)
- # with open(file_name, "w") as logp_file:
- # for l in vl:
- # s = " ".join([str(x.item()) for x in l])
- # logp_file.write(s + "\n")
-
quiz_machine.save_quiz_illustrations(
args.result_dir, prefix, vq, show_part_to_predict=False
)
# re-compute the test errors
if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
- create_c_quizzes(
+ record_new_c_quizzes(
models,
quiz_machine,
nb_for_train=args.nb_new_c_quizzes_for_train,