parser.add_argument("--proba_not_understands", type=float, default=0.5)
-parser.add_argument("--generation_temperature", type=float, default=2)
+# parser.add_argument("--generation_temperature", type=float, default=2)
parser.add_argument("--c_quiz_validation_mode", type=str, default="predict")
+# When set, world quizzes are kept in forward (prompt -> answer) order only
+# (no random time reversal) and c-quiz generation takes the forward-only path.
+parser.add_argument("--forward_only", action="store_true", default=False)
+
parser.add_argument("--dirty_debug", action="store_true", default=False)
######################################################################
while nb_validated < nb_to_create:
+# Pick one of the current models at random to generate candidate c-quizzes.
model_for_generation = models[torch.randint(len(models), (1,))]
-    c_quizzes = quiz_machine.generate_quizzes(
+    # Renamed helper; the forward_only flag replaces the old sampling
+    # temperature knob.
+    c_quizzes = quiz_machine.generate_c_quizzes(
nb_to_generate_per_iteration,
model_for_generation=model_for_generation,
-        temperature=args.generation_temperature,
+        forward_only=args.forward_only,
)
+# Keep only the candidates that pass validation against the models.
c_quizzes = keep_good_quizzes(models, c_quizzes)
model.main_test_accuracy = 0.0
model.id = k
-    model.train_w_quizzes = quiz_machine.generate_token_sequences(args.nb_train_samples)
-    quiz_machine.reverse_random_half_in_place(model.train_w_quizzes)
-    model.test_w_quizzes = quiz_machine.generate_token_sequences(args.nb_test_samples)
-    quiz_machine.reverse_random_half_in_place(model.test_w_quizzes)
+    # World-quiz creation is now centralized in create_w_quizzes, which
+    # honors the new forward_only flag instead of always reversing half.
+    quiz_machine.create_w_quizzes(
+        model=model,
+        nb=args.nb_train_samples,
+        for_train=True,
+        forward_only=args.forward_only,
+    )
+
+    quiz_machine.create_w_quizzes(
+        model=model,
+        nb=args.nb_test_samples,
+        for_train=False,
+        forward_only=args.forward_only,
+    )
models.append(model)
# Renew the training samples
for model in weakest_models:
-        quiz_machine.renew_w_quizzes(model, args.nb_train_samples)
-
+    # Pass forward_only through so renewed quizzes follow the same
+    # time-ordering regime as the initial ones.
+    quiz_machine.renew_w_quizzes(
+        model=model,
+        nb=args.nb_train_samples,
+        for_train=True,
+        forward_only=args.forward_only,
+    )
######################################################################
######################################################################
-    def renew_w_quizzes(self, model, nb, for_train=True):
+    def create_w_quizzes(self, model, nb, for_train=True, forward_only=False):
+        """Generate nb fresh world quizzes and attach them to model.
+
+        Unless forward_only is set, a random half of the generated
+        sequences are time-reversed in place. The batch is stored as
+        model.train_w_quizzes or model.test_w_quizzes depending on
+        for_train.
+        """
+        input = self.generate_token_sequences(nb)
+
+        if not forward_only:
+            self.reverse_random_half_in_place(input)
+
+        if for_train:
+            model.train_w_quizzes = input
+        else:
+            model.test_w_quizzes = input
+
+    ######################################################################
+
+    def renew_w_quizzes(self, model, nb, for_train=True, forward_only=False):
+        """Age out the nb oldest world quizzes of model and add fresh ones.
+
+        The retained quizzes are shifted toward the front and the last nb
+        rows are overwritten with newly generated sequences; unless
+        forward_only is set, a random half of the fresh batch is
+        time-reversed in place.
+        """
input = model.train_w_quizzes if for_train else model.test_w_quizzes
nb = min(nb, input.size(0))
+        # Shift kept quizzes forward; clone() avoids overlapping-copy aliasing.
input[:-nb] = input[nb:].clone()
fresh_w_quizzes = self.generate_token_sequences(nb)
-        self.reverse_random_half_in_place(fresh_w_quizzes)
+        if not forward_only:
+            self.reverse_random_half_in_place(fresh_w_quizzes)
input[-nb:] = fresh_w_quizzes.to("cpu")
######################################################################
###############################################################
-    def generate_quizzes(self, nb, model_for_generation, temperature=1.0):
+    def generate_c_quizzes(self, nb, model_for_generation, forward_only=False):
+        """Sample nb culture quizzes from model_for_generation.
+
+        With forward_only the quiz is produced directly in forward
+        (prompt -> answer) order. Otherwise it is produced in backward
+        order, time-reversed, and its answer is re-generated forward.
+        NOTE(review): reverse_time is assumed to swap the prompt/answer
+        segments and direction tokens -- confirm in its definition.
+        """
c_quizzes = torch.empty(
nb,
self.prompt_len + self.answer_len + 2,
seq_logproba = torch.zeros(nb, device=self.device)
-        c_quizzes[:, 0] = self.token_forward
-        c_quizzes[:, 1 + self.prompt_len] = self.token_forward
-
-        masked_inplace_autoregression(
-            model=model_for_generation,
-            batch_size=self.batch_size,
-            input=c_quizzes,
-            ar_mask=self.make_ar_mask(c_quizzes, first=True),
-            seq_logproba=seq_logproba,
-            temperature=1.0,
-            deterministic_synthesis=False,
-            device=self.device,
-        )
+        if forward_only:
+            # Forward layout: direction tokens delimit the prompt segment.
+            c_quizzes[:, 0] = self.token_forward
+            c_quizzes[:, 1 + self.prompt_len] = self.token_forward
-        masked_inplace_autoregression(
-            model=model_for_generation,
-            batch_size=self.batch_size,
-            input=c_quizzes,
-            ar_mask=self.make_ar_mask(c_quizzes),
-            seq_logproba=seq_logproba,
-            temperature=1,
-            deterministic_synthesis=False,
-            device=self.device,
-        )
+            # Two passes: first=True presumably masks the first segment
+            # (prompt), the second pass fills the rest -- confirm in
+            # make_ar_mask.
+            masked_inplace_autoregression(
+                model=model_for_generation,
+                batch_size=self.batch_size,
+                input=c_quizzes,
+                ar_mask=self.make_ar_mask(c_quizzes, first=True),
+                seq_logproba=seq_logproba,
+                temperature=1.0,
+                deterministic_synthesis=False,
+                device=self.device,
+            )
+
+            masked_inplace_autoregression(
+                model=model_for_generation,
+                batch_size=self.batch_size,
+                input=c_quizzes,
+                ar_mask=self.make_ar_mask(c_quizzes),
+                seq_logproba=seq_logproba,
+                temperature=1,
+                deterministic_synthesis=False,
+                device=self.device,
+            )
+
+        else:
+            # Backward layout: the answer segment comes first, so the
+            # second direction token sits after answer_len positions.
+            c_quizzes[:, 0] = self.token_backward
+            c_quizzes[:, 1 + self.answer_len] = self.token_backward
+
+            # Generate the quiz in backward order (answer then prompt).
+            masked_inplace_autoregression(
+                model=model_for_generation,
+                batch_size=self.batch_size,
+                input=c_quizzes,
+                ar_mask=self.make_ar_mask(c_quizzes, first=True),
+                seq_logproba=seq_logproba,
+                temperature=1.0,
+                deterministic_synthesis=False,
+                device=self.device,
+            )
+
+            masked_inplace_autoregression(
+                model=model_for_generation,
+                batch_size=self.batch_size,
+                input=c_quizzes,
+                ar_mask=self.make_ar_mask(c_quizzes),
+                seq_logproba=seq_logproba,
+                temperature=1,
+                deterministic_synthesis=False,
+                device=self.device,
+            )
+
+            # Map the backward quiz to forward order, then re-generate the
+            # final segment as a forward-consistency pass.
+            c_quizzes = self.reverse_time(c_quizzes)
+
+            masked_inplace_autoregression(
+                model=model_for_generation,
+                batch_size=self.batch_size,
+                input=c_quizzes,
+                ar_mask=self.make_ar_mask(c_quizzes),
+                seq_logproba=seq_logproba,
+                temperature=1,
+                deterministic_synthesis=False,
+                device=self.device,
+            )
return c_quizzes.to("cpu")