self.prompt_len = None
self.answer_len = None
- self.configurations = [
+ self.train_struct = [
("A", "f_A", "B", "f_B"), # The standard order
("f_A", "A", "f_B", "B"), # The reverse order for validation
("f_B", "f_A", "A", "B"), # The synthesis order
######################################################################
+ # New thin wrapper: validate that `struct` is one of the orderings this
+ # machine trains on (self.train_struct), then delegate mask construction
+ # to the underlying problem object.
+ # NOTE(review): `assert` is stripped under `python -O`, so this validation
+ # silently disappears in optimized runs — consider raising ValueError instead.
+ # NOTE(review): another hunk in this same patch rewrites a call site to
+ # `self.make_ar_mask(input, shape="fwd_3_bck_123")`, but this signature
+ # accepts only (quizzes, struct, mask) — that call would raise TypeError;
+ # confirm the legacy `shape=` argument was meant to be converted to
+ # explicit (struct, mask) arguments.
+ def make_ar_mask(self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)):
+ assert struct in self.train_struct
+ return self.problem.make_ar_mask(quizzes, struct, mask)
+
def predict(self, model, quizzes, struct, mask):
- ar_mask = self.problem.make_ar_mask(quizzes=quizzes, struct=struct, mask=mask)
+ ar_mask = self.make_ar_mask(quizzes=quizzes, struct=struct, mask=mask)
result = quizzes * (1 - ar_mask)
seq_logproba = torch.empty(quizzes.size(0), device=self.device)
predicted_parts = predicted_parts[:128]
correct_parts = correct_parts[:128]
- self.problem.reconfigure(
+ result, predicted_parts, correct_parts = self.problem.reconfigure(
[result, predicted_parts, correct_parts], ("A", "f_A", "B", "f_B")
)
model.test_w_quizzes = self.problem.generate_w_quizzes(nb_test_samples)
self.randomize_configuations_inplace(
- model.train_w_quizzes, configurations=self.configurations
+ model.train_w_quizzes, configurations=self.train_struct
)
self.randomize_configuations_inplace(
- model.test_w_quizzes, configurations=self.configurations
+ model.test_w_quizzes, configurations=self.train_struct
)
######################################################################
)
self.randomize_configuations_inplace(
- model.train_w_quizzes, configurations=self.configurations
+ model.train_w_quizzes, configurations=self.train_struct
)
######################################################################
c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
):
input = input.to(self.device)
- ar_mask = self.problem.make_ar_mask(input, shape="fwd_3_bck_123")
+ ar_mask = self.make_ar_mask(input, shape="fwd_3_bck_123")
output = model(mygpt.BracketedSequence(input)).x
l[:, model.id] = (
-F.cross_entropy(
# A, f(A), B | f(B)
result = c_quizzes.clone()
- ar_mask = self.problem.make_ar_mask(
+ ar_mask = self.make_ar_mask(
result, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)
)
# f(A), A, f(B) | B
result = reversed_c_quizzes.clone()
- ar_mask = self.problem.make_ar_mask(
+ ar_mask = self.make_ar_mask(
result, ("f_A", "A", "f_B", "B"), mask=(0, 0, 0, 1)
)
model=model_for_generation,
batch_size=self.batch_size,
input=c_quizzes,
- ar_mask=self.problem.make_ar_mask(
+ ar_mask=self.make_ar_mask(
c_quizzes, ("f_B", "f_A", "A", "B"), (1, 0, 0, 0)
),
seq_logproba=seq_logproba,
model=model_for_generation,
batch_size=self.batch_size,
input=c_quizzes,
- ar_mask=self.problem.make_ar_mask(
+ ar_mask=self.make_ar_mask(
c_quizzes, ("f_B", "f_A", "A", "B"), (0, 1, 1, 1)
),
seq_logproba=seq_logproba,
model=model_for_generation,
batch_size=self.batch_size,
input=c_quizzes,
- ar_mask=self.problem.make_ar_mask(
+ ar_mask=self.make_ar_mask(
c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
),
seq_logproba=seq_logproba,