("gray", [128, 128, 128]),
]
- def make_ar_mask(self, quizzes, first=False):
+ def make_ar_mask(self, quizzes, shape="fwd_3_bck_123"):
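+ # A quiz is four grids of S cells, each prefixed by a structure token,
+ # i.e. four segments of S + 1 tokens. The shape name lists the segments
+ # selected for prediction, first for a forward quiz, then for a
+ # backward one; the structure tokens themselves are never selected.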
S = self.height * self.width
assert quizzes.size(1) == 4 * (S + 1)
T = torch.arange(quizzes.size(1), device=quizzes.device)
- if first:
+ if shape == "fwd_3_bck_123":
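+ # forward quiz: predict segment 3 (the answer); backward: segments 1-3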
+ forward_mask = ((T % (S + 1) != 0) & (T >= 3 * (S + 1))).long()
+ backward_mask = ((T % (S + 1) != 0) & (T >= S + 1)).long()
+ elif shape == "fwd_012_bck_0":
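+ # forward quiz: predict segments 0-2 (the context); backward: segment 0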
forward_mask = ((T % (S + 1) != 0) & (T < 3 * (S + 1))).long()
backward_mask = ((T % (S + 1) != 0) & (T < S + 1)).long()
- else:
+ elif shape == "fwd_3_bck_3":
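+ # both directions: predict segment 3 only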
forward_mask = ((T % (S + 1) != 0) & (T >= 3 * (S + 1))).long()
- backward_mask = ((T % (S + 1) != 0) & (T >= S + 1)).long()
+ backward_mask = ((T % (S + 1) != 0) & (T >= 3 * (S + 1))).long()
+ else:
+ raise ValueError(shape)
is_forward = (quizzes[:, 0] == self.token_forward).long()
return (
is_forward[:, None] * forward_mask[None, :]
+ (1 - is_forward)[:, None] * backward_mask[None, :]
)
- def p_a_flip(self, quizzes):
+ def p_a_flip(self, quizzes, pairwise_flip=False):
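+ # A quiz is the four grids A, f(A), B, f(B). With pairwise_flip the
+ # two grids of each pair are swapped, giving f(A), A, f(B), B; the
+ # default reverses the segment order, giving f(B), B, f(A), A.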
S = self.height * self.width
assert (
(quizzes[:, 0] == quizzes[:, 1 * (S + 1)])
& (quizzes[:, 0] == quizzes[:, 2 * (S + 1)])
& (quizzes[:, 0] == quizzes[:, 3 * (S + 1)])
).all()
- flipped = torch.cat(
- [quizzes[:, k * (S + 1) : (k + 1) * (S + 1)] for k in range(3, -1, -1)],
- dim=1,
- )
+ if pairwise_flip:
+ flipped = torch.cat(
+ [
+ quizzes[:, 1 * (S + 1) : 1 * (S + 1) + S + 1],
+ quizzes[:, 0 * (S + 1) : 0 * (S + 1) + S + 1],
+ quizzes[:, 3 * (S + 1) : 3 * (S + 1) + S + 1],
+ quizzes[:, 2 * (S + 1) : 2 * (S + 1) + S + 1],
+ ],
+ dim=1,
+ )
+ else:
+ flipped = torch.cat(
+ [
+ quizzes[:, 3 * (S + 1) : 3 * (S + 1) + S + 1],
+ quizzes[:, 2 * (S + 1) : 2 * (S + 1) + S + 1],
+ quizzes[:, 1 * (S + 1) : 1 * (S + 1) + S + 1],
+ quizzes[:, 0 * (S + 1) : 0 * (S + 1) + S + 1],
+ ],
+ dim=1,
+ )
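+ # toggle the leading structure token so the flipped quiz is tagged
+ # with the opposite direction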
m = (flipped[:, 0] == self.token_forward).long()
flipped[:, 0 * (S + 1)] = m * self.token_backward + (1 - m) * self.token_forward
def produce_results(self, n_epoch, model, result_dir, deterministic_synthesis):
def compute_accuracy(input, log_prefix=None):
input = input.to(self.device)
- ar_mask = self.problem.make_ar_mask(input)
+ ar_mask = self.problem.make_ar_mask(input, shape="fwd_3_bck_123")
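+ # erase the positions to be predicted; the model regenerates them in
+ # place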
result = input.clone() * (1 - ar_mask)
seq_logproba = torch.empty(input.size(0), device=self.device)
for input, l in zip(
c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
):
input = input.to(self.device)
- ar_mask = self.problem.make_ar_mask(input)
+ ar_mask = self.problem.make_ar_mask(input, shape="fwd_3_bck_123")
output = model(mygpt.BracketedSequence(input)).x
l[:, model.id] = (
-F.cross_entropy(
output.transpose(1, 2), input, reduction="none"
) * ar_mask
).sum(dim=1)
###############################################################
def solution_nb_correct(
- self,
- models_for_validation,
- c_quizzes,
- deterministic_validation=False,
+ self, models_for_validation, c_quizzes, bidirectional_validation=True
):
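+ # a quiz counts as correct for a model only if the model reconstructs
+ # the masked grid in both the native and the pairwise-flipped
+ # presentation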
seq_logproba = torch.zeros(
c_quizzes.size(0),
max([m.id for m in models_for_validation]) + 1,
device=self.device,
)
seq_logproba[...] = 0.0
for model in models_for_validation:
+ # A, f(A), B | f(B)
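+ # first pass: native order, mask and regenerate the last grid f(B)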
c_quizzes = c_quizzes.to(self.device)
result = c_quizzes.clone()
- ar_mask = self.problem.make_ar_mask(result)
+ ar_mask = self.problem.make_ar_mask(result, shape="fwd_3_bck_3")
masked_inplace_autoregression(
model=model,
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
seq_logproba=seq_logproba[:, model.id],
temperature=1.0,
- deterministic_synthesis=deterministic_validation,
+ deterministic_synthesis=False,
device=self.device,
)
correct = (c_quizzes == result).long().min(dim=-1).values
- nb_correct += correct
+ # -------------------------------
+
+ # f(A), A, f(B) | B
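+ # second pass: pairwise-flipped order, mask and regenerate B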
+ c_quizzes = self.problem.p_a_flip(c_quizzes, pairwise_flip=True).to(
+ self.device
+ )
+ result = c_quizzes.clone()
+
+ ar_mask = self.problem.make_ar_mask(result, shape="fwd_3_bck_3")
+
+ masked_inplace_autoregression(
+ model=model,
+ batch_size=self.batch_size,
+ input=result,
+ ar_mask=ar_mask,
+ seq_logproba=seq_logproba[:, model.id],
+ temperature=1.0,
+ deterministic_synthesis=False,
+ device=self.device,
+ )
+
+ flipped_correct = (c_quizzes == result).long().min(dim=-1).values
+
+ # -------------------------------
+
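+ # the two 0/1 indicators multiply to a logical AND over both passes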
+ nb_correct += correct * flipped_correct
return nb_correct.to("cpu")
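+ # hot phase: sample the segments selected by "fwd_012_bck_0" at a high
+ # temperature to get diverse candidate quizzes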
masked_inplace_autoregression(
model=model_for_generation,
batch_size=self.batch_size,
input=c_quizzes,
- ar_mask=self.problem.make_ar_mask(c_quizzes, first=True),
+ ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"),
seq_logproba=seq_logproba,
temperature=temperature_hot,
deterministic_synthesis=False,
device=self.device,
)
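+ # cold phase: regenerate the remaining segments at a low temperature
+ # for consistent completions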
masked_inplace_autoregression(
model=model_for_generation,
batch_size=self.batch_size,
input=c_quizzes,
- ar_mask=self.problem.make_ar_mask(c_quizzes),
+ ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
seq_logproba=seq_logproba,
temperature=temperature_cold,
deterministic_synthesis=False,
device=self.device,
)
masked_inplace_autoregression(
model=model_for_generation,
batch_size=self.batch_size,
input=c_quizzes,
- ar_mask=self.problem.make_ar_mask(c_quizzes, first=True),
+ ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"),
seq_logproba=seq_logproba,
temperature=temperature_hot,
deterministic_synthesis=False,
device=self.device,
)
masked_inplace_autoregression(
model=model_for_generation,
batch_size=self.batch_size,
input=c_quizzes,
- ar_mask=self.problem.make_ar_mask(c_quizzes),
+ ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
seq_logproba=seq_logproba,
temperature=temperature_cold,
deterministic_synthesis=False,
device=self.device,
)
masked_inplace_autoregression(
model=model_for_generation,
batch_size=self.batch_size,
input=c_quizzes,
- ar_mask=self.problem.make_ar_mask(c_quizzes),
+ ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
seq_logproba=seq_logproba,
temperature=temperature_cold,
deterministic_synthesis=False,
device=self.device,
)