nb_test_samples, acc_test_loss = 0, 0.0
nb_samples_accumulated = 0
- full_input, full_mask_loss = quiz_machine.data_input(model, split="test")
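+ # data_input now takes a sample count and draws quizzes on demand,
+ # rather than reading from a precomputed train/test split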
+ full_input, full_mask_loss = quiz_machine.data_input(
+ model, args.nb_test_samples
+ )
src = zip(
full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
)
hard_w_quizzes = []
- full_input, full_mask_loss = quiz_machine.data_input(model, split="train")
+ full_input, full_mask_loss = quiz_machine.data_input(model, args.nb_train_samples)
src = zip(full_input.split(args.batch_size), full_mask_loss.split(args.batch_size))
for input, mask_loss in tqdm.tqdm(
nb_validated_per_model = torch.zeros(len(models), dtype=torch.int64)
while nb_validated_per_model.sum() < nb_to_validate:
- # We use the model that has generated the fewest quizzes to
- # balance the number of quizzes per model overall
-
- # model_for_generation = sorted(
- # models, key=lambda m: nb_validated_per_model[m.id]
- # )[0]
-
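+ # The generating model is picked uniformly at random among all models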
model_for_generation = models[torch.randint(len(models), (1,)).item()]
# We generate quizzes with a procedure that injects some
# This is nb_quizzes x nb_models
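+ # solved_c_quizzes is nb_quizzes x nb_models x seq_length:
+ # one copy of every candidate quiz per model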
+ solved_c_quizzes = c_quizzes[:, None, :].expand(-1, len(models), -1).clone()
+
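+ # Each model fills in the masked f_B part of its own copies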
+ for m in models:
+ solved_c_quizzes[:, m.id] = quiz_machine.predict(
+ m,
+ solved_c_quizzes[:, m.id],
+ struct=("A", "f_A", "B", "f_B"),
+ mask=(0, 0, 0, 1),
+ )
+
+ # FINISH
+
seq_logproba = quiz_machine.models_logprobas(
models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
) + quiz_machine.models_logprobas(
).to(main_device)
model.id = k
+ model.c_quiz_bags = []
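+ # Per-model bags of validated c-quizzes, mixed into the samples
+ # returned by data_input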
if args.schedule_free:
model.optimizer = schedulefree.AdamWScheduleFree(
model.main_test_accuracy = 0.0
- model.train_w_quizzes = quiz_machine.problem.generate_w_quizzes(
- args.nb_train_samples
- )
-
- model.test_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples)
-
models.append(model)
######################################################################
record_new_c_quizzes(
models,
quiz_machine,
nb_for_train=args.nb_new_c_quizzes_for_train,
nb_for_test=args.nb_new_c_quizzes_for_test,
)
######################################################################
- # Renew the training samples
-
- for model in weakest_models:
- quiz_machine.renew_train_w_quizzes(model=model)
-
if args.log_command is not None:
s = args.log_command.split()
s.insert(1, args.result_dir)
self.test_structures = self.train_structures
- self.LOCK_C_QUIZZES = threading.Lock()
- self.train_c_quizzes = []
- self.test_c_quizzes = []
-
def vocabulary_size(self):
return self.problem.nb_token_values
######################################################################
- def data_input(self, model, split="train"):
- assert split in {"train", "test"}
-
- with self.LOCK_C_QUIZZES:
- if split == "train":
- w_quizzes = model.train_w_quizzes
- c_quizzes = self.train_c_quizzes
- else:
- w_quizzes = model.test_w_quizzes
- c_quizzes = self.test_c_quizzes
-
- if len(c_quizzes) > 0:
- c_quizzes = torch.cat(c_quizzes, dim=0)
+ def data_input(self, model, nb_samples):
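+ # Return nb_samples quizzes: at most half are c-quizzes drawn from the
+ # model's bags, the rest are freshly generated w-quizzes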
+ if len(model.c_quiz_bags) > 0:
+ c_quizzes = torch.cat(model.c_quiz_bags, dim=0)
- if c_quizzes.size(0) > w_quizzes.size(0) // 2:
- i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
- c_quizzes = c_quizzes[i]
+ if c_quizzes.size(0) > nb_samples // 2:
+ i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2]
+ c_quizzes = c_quizzes[i]
- i = torch.randperm(w_quizzes.size(0))[
- : w_quizzes.size(0) - c_quizzes.size(0)
- ]
- w_quizzes = w_quizzes[i]
-
- quizzes = torch.cat([w_quizzes, c_quizzes], dim=0)
- from_w = torch.arange(
- quizzes.size(0), device=quizzes.device
- ) < w_quizzes.size(0)
-
- else:
- quizzes = w_quizzes.clone()
- from_w = torch.full((quizzes.size(0),), True, device=quizzes.device)
+ w_quizzes = self.problem.generate_w_quizzes(nb_samples - c_quizzes.size(0))
+ quizzes = torch.cat([w_quizzes, c_quizzes], dim=0)
+ else:
+ quizzes = self.problem.generate_w_quizzes(nb_samples)
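+ # Shuffle so that w-quizzes and c-quizzes are interleaved when both are present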
i = torch.randperm(quizzes.size(0), device=quizzes.device)
- quizzes, from_w = quizzes[i], from_w[i]
+ quizzes = quizzes[i]
self.randomize_configuations_inplace(
quizzes, structs=[s for s, _, _, _ in self.train_structures]
######################################################################
- def renew_train_w_quizzes(self, model):
- if hasattr(model, "hard_w_quizzes"):
- hard_w_quizzes = self.problem.reconfigure(
- model.hard_w_quizzes, struct=("A", "f_A", "B", "f_B")
- )
- self.logger(
- f"re-using {hard_w_quizzes.size(0)} hard world quizzes from model {model.id}"
- )
- if hard_w_quizzes.size(0) >= model.train_w_quizzes.size(0):
- nb_to_generate = 0
- model.train_w_quizzes[...] = hard_w_quizzes[
- torch.randperm(hard_w_quizzes.size(0))[
- model.train_w_quizzes.size(0)
- ]
- ]
- else:
- nb_to_generate = model.train_w_quizzes.size(0) - hard_w_quizzes.size(0)
- model.train_w_quizzes[...] = torch.cat(
- [
- hard_w_quizzes,
- self.problem.generate_w_quizzes(nb_to_generate),
- ],
- dim=0,
- )
- else:
- nb_to_generate = 0
- model.train_w_quizzes[...] = self.problem.generate_w_quizzes(
- model.train_w_quizzes.size(0)
- )
-
- ######################################################################
-
def store_c_quizzes(self, new_c_quizzes, for_train=True):
with self.LOCK_C_QUIZZES:
if for_train: