nb_validated_per_model = torch.zeros(len(models), dtype=torch.int64)
while nb_validated_per_model.sum() < nb_to_validate:
- # We balance the number of quizzes per model
+ # We use the model that has generated the fewest quizzes to
+ # balance the number of quizzes per model overall
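+ # (e.g. with nb_validated_per_model = [12, 5, 9], model 1 is the
+ # next generator)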
model_for_generation = sorted(
models, key=lambda m: nb_validated_per_model[m.id]
)[0]

# Generate candidate quizzes with that model (the generator call is
# assumed; only temperature_cold is certain from the arguments shown)
c_quizzes = quiz_machine.generate_c_quizzes(
model_for_generation=model_for_generation,
temperature_cold=args.temperature_cold,
)
- # We discard the trivial ones
+ # We discard the trivial ones, according to a criterion
+ # specific to the world quizzes (e.g. B=f(B))
c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
# We go through nb_rounds rounds and keep only quizzes on
- # which models respond always the same through rounds and one
- # which N-1 succeed and one fails
+ # which
+ #
+ # (1) all models respond consistently across rounds, and
+ #
+ # (2) at least one and at most max_fail_to_validate models fail
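+ #
+ # For instance with 4 models, nb_rounds = 3, and
+ # max_fail_to_validate = 1, per-model correct counts of
+ # [3, 3, 0, 3] validate a quiz (one model consistently fails),
+ # while [3, 2, 0, 3] discards it (one model is inconsistent)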
- ms = 0 # "model scores"
+ # Accumulates below into a nb_quizzes x nb_models tensor of
+ # per-model counts of correct responses
+ number_correct_responses = 0
for r in range(args.nb_rounds):
- ms += quiz_machine.models_successes(models, c_quizzes)
- nb_sure_and_correct = (ms == r + 1).long().sum(dim=1)
- nb_sure_and_fail = (ms == 0).long().sum(dim=1)
+ number_correct_responses += quiz_machine.models_successes(models, c_quizzes)
+
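+ # Per quiz: how many models answered correctly in all r+1 rounds
+ # so far, and how many answered incorrectly in all of them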
+ nb_sure_correct = (number_correct_responses == r + 1).long().sum(dim=1)
+ nb_sure_fail = (number_correct_responses == 0).long().sum(dim=1)
+
to_keep = (
- (nb_sure_and_correct + nb_sure_and_fail == ms.size(1))
- & (nb_sure_and_fail >= 1)
- & (nb_sure_and_fail <= args.max_fail_to_validate)
+ (nb_sure_correct + nb_sure_fail == number_correct_responses.size(1))
+ & (nb_sure_fail >= 1)
+ & (nb_sure_fail <= args.max_fail_to_validate)
)
c_quizzes = c_quizzes[to_keep]
- ms = ms[to_keep]
- print(f"Round {r} remains {c_quizzes.size(0)}")
+ number_correct_responses = number_correct_responses[to_keep]
+
+ log_string(f"round {r} remains {c_quizzes.size(0)}")
+
if c_quizzes.size(0) == 0:
break
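+
+ # Quizzes that survive every round are validated; the loop
+ # condition implies they are then counted in nb_validated_per_model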
for k in range(args.nb_gpts):
log_string(f"creating model {k} and its w_quizzes")
+
model = mygpt.MyGPT(
vocabulary_size=vocabulary_size,
dim_model=args.dim_model,
# (the remaining constructor arguments are assumed)
dim_keys=args.dim_keys,
dim_hidden=args.dim_hidden,
nb_heads=args.nb_heads,
nb_blocks=args.nb_blocks,
causal=True,
dropout=args.dropout,
)

quiz_machine.create_w_quizzes(
model=model,
- nb=args.nb_train_samples,
- for_train=True,
- p2a_only=args.p2a_only,
- )
-
- quiz_machine.create_w_quizzes(
- model=model,
- nb=args.nb_test_samples,
- for_train=False,
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
p2a_only=args.p2a_only,
)
# Renew the training samples
for model in weakest_models:
- quiz_machine.renew_w_quizzes(
- model=model,
- for_train=True,
- p2a_only=args.p2a_only,
- )
+ quiz_machine.renew_train_w_quizzes(model=model, p2a_only=args.p2a_only)
if args.log_command is not None:
s = args.log_command.split()
######################################################################
- def create_w_quizzes(self, model, nb, for_train=True, p2a_only=False):
- input = self.generate_token_sequences(nb)
+ def create_w_quizzes(
+ self, model, nb_train_samples, nb_test_samples, p2a_only=False
+ ):
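+ # Fresh token sequences: one set for training, one for testing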
+ model.train_w_quizzes = self.generate_token_sequences(nb_train_samples)
+ model.test_w_quizzes = self.generate_token_sequences(nb_test_samples)
if not p2a_only:
- self.p_a_flip_half_in_place(input)
-
- if for_train:
- model.train_w_quizzes = input
- else:
- model.test_w_quizzes = input
+ self.p_a_flip_half_in_place(model.train_w_quizzes)
+ self.p_a_flip_half_in_place(model.test_w_quizzes)
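+ # (p_a_flip_half_in_place presumably flips half of the quizzes
+ # from problem->answer to answer->problem order, in place, so the
+ # model sees both directions)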
######################################################################
- def renew_w_quizzes(self, model, for_train=True, p2a_only=False):
- input = model.train_w_quizzes if for_train else model.test_w_quizzes
-
+ def renew_train_w_quizzes(self, model, p2a_only=False):
- if for_train and hasattr(model, "hard_w_quizzes"):
+ if hasattr(model, "hard_w_quizzes"):
self.logger(
f"re-using {model.hard_w_quizzes.size(0)} hard world quizzes from model {model.id}"
)
- if model.hard_w_quizzes.size(0) >= input.size(0):
- input[...] = model.hard_w_quizzes[
- torch.randperm(hard_w_quizzes.size(0))[input.size(0)]
+
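+ # If there are enough hard quizzes, fill the training set with a
+ # random subset of them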
+ if model.hard_w_quizzes.size(0) >= model.train_w_quizzes.size(0):
+ model.train_w_quizzes[...] = model.hard_w_quizzes[
+ torch.randperm(model.hard_w_quizzes.size(0))[
+ : model.train_w_quizzes.size(0)
+ ]
+ ]
else:
- input[...] = torch.cat(
+ model.train_w_quizzes[...] = torch.cat(
[
model.hard_w_quizzes,
self.generate_token_sequences(
- input.size(0) - model.hard_w_quizzes.size(0)
+ model.train_w_quizzes.size(0) - model.hard_w_quizzes.size(0)
),
],
dim=0,
)
else:
- input[...] = self.generate_token_sequences(input.size(0))
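+ # No hard quizzes recorded for this model: resample the whole
+ # training set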
+ model.train_w_quizzes[...] = self.generate_token_sequences(
+ model.train_w_quizzes.size(0)
+ )
if not p2a_only:
- self.p_a_flip_half_in_place(input)
+ self.p_a_flip_half_in_place(model.train_w_quizzes)
######################################################################