tasks=None,
):
self.colors = torch.tensor([c for _, c in self.named_colors])
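+ # two structure tokens appended after the color values; they mark
+ # whether a quiz reads prompt->answer (forward) or answer->prompt (backward)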
+ self.token_forward = len(self.colors)
+ self.token_backward = self.token_forward + 1
self.height = 10
self.width = 10
self.cache_rec_coo = {}
def frame2img(self, x, scale=15):
x = x.reshape(x.size(0), self.height, -1)
- m = torch.logical_and(x >= 0, x < self.nb_token_values()).long()
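+ # mask to valid color indices so the structure tokens (>= len(colors))
+ # fall back to color 0 and the frame still renders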
+ m = torch.logical_and(x >= 0, x < len(self.colors)).long()
x = self.colors[x * m].permute(0, 3, 1, 2)
s = x.shape
x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
margin=8,
):
S = self.height * self.width
- As = prompts[:, 0 * (S + 1) : 0 * (S + 1) + S].view(-1, self.height, self.width)
- f_As = prompts[:, 1 * (S + 1) : 1 * (S + 1) + S].view(-1, self.height, self.width)
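+ # each grid occupies an (S + 1)-token block whose first token is a
+ # structure marker, hence the k * (S + 1) + 1 offsets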
+ As = prompts[:, 0 * (S + 1) + 1 : 0 * (S + 1) + S + 1].view(
+ -1, self.height, self.width
+ )
+ f_As = prompts[:, 1 * (S + 1) + 1 : 1 * (S + 1) + S + 1].view(
+ -1, self.height, self.width
+ )
+ Bs = prompts[:, 2 * (S + 1) + 1 : 2 * (S + 1) + S + 1].view(
+ -1, self.height, self.width
+ )
- Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S].view(-1, self.height, self.width)
prompts = torch.cat([As, f_As, Bs], dim=2)
- answers = answers.reshape(answers.size(0), self.height, self.width)
+ answers = answers[:, 1 : S + 1].reshape(
+ answers.size(0), self.height, self.width
+ )
if predicted_prompts is None:
predicted_prompts = 255
######################################################################
def nb_token_values(self):
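+ # the color values plus token_forward and token_backward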
- return len(self.colors)
+ return len(self.colors) + 2
# @torch.compile
def rec_coo(
def trivial_prompts_and_answers(self, prompts, answers):
S = self.height * self.width
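+ # B is the third (S + 1)-block of the prompt; skip its leading marker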
- Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S]
- f_Bs = answers
+ Bs = prompts[:, 2 * (S + 1) + 1 : 2 * (S + 1) + S + 1]
+ f_Bs = answers[:, 1:]
+ print(f"{prompts.size()=} {answers.size()=} {Bs.size()=} {f_Bs.size()=}")
return (Bs == f_Bs).long().min(dim=-1).values > 0
def generate_prompts_and_answers_(self, nb, tasks=None, progress_bar=False):
tasks = self.all_tasks
S = self.height * self.width
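+ # layout with markers embedded: prompts = [tok, A, tok, f_A, tok, B]
+ # (3 * (S + 1) = 3 * S + 3 tokens), answers = [tok, f_B] (S + 1 tokens),
+ # pre-filled with token_forward; with height = width = 10 that is
+ # 303 and 101 tokens respectively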
- prompts = torch.zeros(nb, 3 * S + 2, dtype=torch.int64)
- answers = torch.zeros(nb, S, dtype=torch.int64)
+ prompts = torch.full((nb, 3 * S + 3), self.token_forward, dtype=torch.int64)
+ answers = torch.full((nb, S + 1), self.token_forward, dtype=torch.int64)
bunch = zip(prompts, answers)
)
for prompt, answer in bunch:
- A = prompt[0 * (S + 1) : 0 * (S + 1) + S].view(self.height, self.width)
- f_A = prompt[1 * (S + 1) : 1 * (S + 1) + S].view(self.height, self.width)
- B = prompt[2 * (S + 1) : 2 * (S + 1) + S].view(self.height, self.width)
- f_B = answer.view(self.height, self.width)
+ A = prompt[0 * (S + 1) + 1 : 0 * (S + 1) + S + 1].view(
+ self.height, self.width
+ )
+ f_A = prompt[1 * (S + 1) + 1 : 1 * (S + 1) + S + 1].view(
+ self.height, self.width
+ )
+ B = prompt[2 * (S + 1) + 1 : 2 * (S + 1) + S + 1].view(
+ self.height, self.width
+ )
+ f_B = answer[1 : S + 1].view(self.height, self.width)
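+ # the chosen task fills A, f_A, B and f_B in place (all four are
+ # views into prompt/answer)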
task = tasks[torch.randint(len(tasks), (1,)).item()]
task(A, f_A, B, f_B)
parser.add_argument("--proba_not_understands", type=float, default=0.5)
-parser.add_argument("--generation_temperature", type=float, default=1.5)
+parser.add_argument("--temperature_hot", type=float, default=1.5)
+
+parser.add_argument("--temperature_cold", type=float, default=0.75)
parser.add_argument("--c_quiz_validation_mode", type=str, default="predict")
-parser.add_argument("--forward_only", action="store_true", default=False)
+parser.add_argument("--p2a_only", action="store_true", default=False)
parser.add_argument("--dirty_debug", action="store_true", default=False)
acc_train_loss += loss.item() * input.size(0)
loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1)
- n_forward = input[:, 0] == quiz_machine.token_forward
- to_store = from_w & n_forward.to("cpu")
+ n_p2a = input[:, 0] == quiz_machine.problem.token_forward
+ to_store = from_w & n_p2a.to("cpu")
if to_store.any():
hard_w_quizzes.append(
(input[to_store].to("cpu"), loss_per_samples[to_store].to("cpu"))
# We balance the number of quizzes per model
model_for_generation = sorted(models, key=lambda m: nb_validated[m.id])[0]
- print(nb_validated, "using", model_for_generation.id)
c_quizzes = quiz_machine.generate_c_quizzes(
nb_to_generate_per_iteration,
model_for_generation=model_for_generation,
- forward_only=args.forward_only,
- generation_temperature=args.generation_temperature,
+ p2a_only=args.p2a_only,
+ temperature_hot=args.temperature_hot,
+ temperature_cold=args.temperature_cold,
)
c_quizzes = keep_good_quizzes(models, c_quizzes)
model=model,
nb=args.nb_train_samples,
for_train=True,
- forward_only=args.forward_only,
+ p2a_only=args.p2a_only,
)
quiz_machine.create_w_quizzes(
model=model,
nb=args.nb_test_samples,
for_train=False,
- forward_only=args.forward_only,
+ p2a_only=args.p2a_only,
)
models.append(model)
quiz_machine.renew_w_quizzes(
model=model,
for_train=True,
- forward_only=args.forward_only,
+ p2a_only=args.p2a_only,
)
if args.log_command is not None:
class QuizMachine:
- def indices_forward_and_backward(self, quizzes):
- i_forward = quizzes[:, 0] == self.token_forward
- j_forward = quizzes[:, 1 + self.prompt_len] == self.token_forward
- i_backward = quizzes[:, 0] == self.token_backward
- j_backward = quizzes[:, 1 + self.answer_len] == self.token_backward
+ def indices_p2a_and_a2p(self, quizzes):
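+ # a p2a quiz carries token_forward at positions 0 and prompt_len,
+ # an a2p quiz token_backward at positions 0 and answer_len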
+ i_p2a = quizzes[:, 0] == self.problem.token_forward
+ j_p2a = quizzes[:, self.prompt_len] == self.problem.token_forward
+ i_a2p = quizzes[:, 0] == self.problem.token_backward
+ j_a2p = quizzes[:, self.answer_len] == self.problem.token_backward
assert torch.logical_or(
- torch.logical_and(i_forward, j_forward),
- torch.logical_and(i_backward, j_backward),
+ torch.logical_and(i_p2a, j_p2a),
+ torch.logical_and(i_a2p, j_a2p),
).all()
- return i_forward, i_backward
+ return i_p2a, i_a2p
def non_trivial(self, quizzes):
quizzes = quizzes.clone()
- n_forward = quizzes[quizzes[:, 0] == self.token_forward]
- n_backward = quizzes[:, 0] == self.token_backward
- backward = quizzes[n_backward]
- quizzes[n_backward] = self.reverse_time(quizzes[n_backward])
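+ # flip a2p quizzes into canonical p2a order before the triviality test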
+ n_a2p = quizzes[:, 0] == self.problem.token_backward
+ quizzes[n_a2p] = self.p_a_flip(quizzes[n_a2p])
return torch.logical_not(
self.problem.trivial_prompts_and_answers(
- quizzes[:, 1 : 1 + self.prompt_len],
- quizzes[:, 2 + self.prompt_len :],
+ quizzes[:, : self.prompt_len], quizzes[:, self.prompt_len :]
)
)
- def reverse_time(self, quizzes):
- i_forward, i_backward = self.indices_forward_and_backward(quizzes)
+ def p_a_flip(self, quizzes):
+ i_p2a, i_a2p = self.indices_p2a_and_a2p(quizzes)
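+ # swap the prompt and answer blocks, then restamp the marker at the
+ # start of each half with the new direction token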
- forward_to_backward = torch.cat(
- [
- quizzes[:, 0:1],
- quizzes[:, 2 + self.prompt_len : 2 + self.prompt_len + self.answer_len],
- quizzes[:, 1 + self.prompt_len : 1 + self.prompt_len + 1],
- quizzes[:, 1 : 1 + self.prompt_len],
- ],
+ p2a_to_a2p = torch.cat(
+ [quizzes[:, self.prompt_len :], quizzes[:, : self.prompt_len]],
dim=1,
)
- forward_to_backward[:, 0] = self.token_backward
- forward_to_backward[:, 1 + self.answer_len] = self.token_backward
+ p2a_to_a2p[:, 0] = self.problem.token_backward
+ p2a_to_a2p[:, self.answer_len] = self.problem.token_backward
- backward_to_forward = torch.cat(
- [
- quizzes[:, 0:1],
- quizzes[:, 2 + self.answer_len :],
- quizzes[:, 1 + self.answer_len : 2 + self.answer_len],
- quizzes[:, 1 : 1 + self.answer_len],
- ],
+ a2p_to_p2a = torch.cat(
+ [quizzes[:, self.answer_len :], quizzes[:, : self.answer_len]],
dim=1,
)
- backward_to_forward[:, 0] = self.token_forward
- backward_to_forward[:, 1 + self.prompt_len] = self.token_forward
+ a2p_to_p2a[:, 0] = self.problem.token_forward
+ a2p_to_p2a[:, self.prompt_len] = self.problem.token_forward
- m = i_forward.long()[:, None]
+ m = i_p2a.long()[:, None]
- return m * forward_to_backward + (1 - m) * backward_to_forward
+ return m * p2a_to_a2p + (1 - m) * a2p_to_p2a
- def reverse_random_half_in_place(self, quizzes):
+ def p_a_flip_half_in_place(self, quizzes):
i = torch.rand(quizzes.size(0)) < 0.5
if i.any():
- quizzes[i] = self.reverse_time(quizzes[i])
+ quizzes[i] = self.p_a_flip(quizzes[i])
def make_ar_mask(self, quizzes, first=False):
- i_forward, i_backward = self.indices_forward_and_backward(quizzes)
+ i_p2a, i_a2p = self.indices_p2a_and_a2p(quizzes)
t = torch.arange(quizzes.size(1), device=quizzes.device)
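+ # first=True selects the body of the leading block (generated first),
+ # otherwise everything past the second block's marker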
if first:
- m_forward = (t >= 1).long() * (t < 1 + self.prompt_len).long()
- m_backward = (t >= 1).long() * (t < 1 + self.answer_len).long()
+ m_p2a = (t >= 1).long() * (t < self.prompt_len).long()
+ m_a2p = (t >= 1).long() * (t < self.answer_len).long()
else:
- m_forward = (t >= 2 + self.prompt_len).long()
- m_backward = (t >= 2 + self.answer_len).long()
+ m_p2a = (t >= 1 + self.prompt_len).long()
+ m_a2p = (t >= 1 + self.answer_len).long()
- m = i_forward.long()[:, None]
+ m = i_p2a.long()[:, None]
- return m * m_forward + (1 - m) * m_backward
+ return m * m_p2a + (1 - m) * m_a2p
def generate_token_sequences(self, nb):
prompts, answers = self.problem.generate_prompts_and_answers(nb)
result = []
for prompt, answer in zip(prompts, answers):
- a = [
- torch.tensor([self.token_forward]),
- prompt,
- torch.tensor([self.token_forward]),
- answer,
- ]
-
- result.append(torch.cat(a, dim=0)[None, :])
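+ # the structure tokens are embedded in prompt and answer, so a plain
+ # concatenation yields a complete quiz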
+ result.append(torch.cat([prompt, answer], dim=0)[None, :])
return torch.cat(result, dim=0)
):
super().__init__()
- v = problem.nb_token_values()
- self.token_forward = v
- self.token_backward = v + 1
- self.nb_token_values = v + 2
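+ # the problem now owns the structure tokens and the full token count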
+ self.nb_token_values = problem.nb_token_values()
self.problem = problem
self.back_accuracy = back_accuracy
show_part_to_predict=True,
):
quizzes = quizzes.clone().to("cpu")
- n_forward = quizzes[quizzes[:, 0] == self.token_forward]
- n_backward = quizzes[:, 0] == self.token_backward
- backward = quizzes[n_backward]
- assert n_forward.size(0) + backward.size(0) == quizzes.size(0)
- quizzes[n_backward] = self.reverse_time(quizzes[n_backward])
+ n_p2a = quizzes[quizzes[:, 0] == self.problem.token_forward]
+ n_a2p = quizzes[:, 0] == self.problem.token_backward
+ a2p = quizzes[n_a2p]
+ assert n_p2a.size(0) + a2p.size(0) == quizzes.size(0)
+ quizzes[n_a2p] = self.p_a_flip(quizzes[n_a2p])
if show_part_to_predict:
- predicted_prompts = n_backward.long()
+ predicted_prompts = n_a2p.long()
predicted_answers = 1 - predicted_prompts
if mistakes is not None:
# 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
correct = torch.empty(input.size(0), dtype=torch.int64, device=input.device)
- n_forward = input[:, 0] == self.token_forward
- n_backward = input[:, 0] == self.token_backward
+ n_p2a = input[:, 0] == self.problem.token_forward
+ n_a2p = input[:, 0] == self.problem.token_backward
- correct[n_forward] = (
- (input[n_forward] == result[n_forward]).long().min(dim=1).values
- )
+ correct[n_p2a] = (input[n_p2a] == result[n_p2a]).long().min(dim=1).values
- if self.back_accuracy and n_backward.any():
+ if self.back_accuracy and n_a2p.any():
# accuracy of B->A*->B*=B instead of B->A*=A
- back_input = self.reverse_time(result[n_backward])
+ back_input = self.p_a_flip(result[n_a2p])
- back_input[:, 2 + self.prompt_len :] = input[
- n_backward, 1 : 1 + self.answer_len
- ]
+ back_input[:, self.prompt_len + 1 :] = input[
+ n_a2p, 1 : self.answer_len
+ ]
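+ # (under the embedded-marker layout the answer content spans
+ # prompt_len + 1 onward, keeping the marker at prompt_len intact)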
- _, correct[n_backward] = compute_accuracy(back_input)
+ _, correct[n_a2p] = compute_accuracy(back_input)
if log_prefix is not None:
- forward_nb_correct = correct[n_forward].sum()
- forward_nb_total = correct[n_forward].size(0)
- backward_nb_correct = correct[n_backward].sum()
- backward_nb_total = correct[n_backward].size(0)
+ p2a_nb_correct = correct[n_p2a].sum()
+ p2a_nb_total = correct[n_p2a].size(0)
+ a2p_nb_correct = correct[n_a2p].sum()
+ a2p_nb_total = correct[n_a2p].size(0)
self.logger(
- f"{log_prefix}_accuracy {n_epoch} model {model.id} forward {forward_nb_correct} / {forward_nb_total} backward {backward_nb_correct} / {backward_nb_total}"
+ f"{log_prefix}_accuracy {n_epoch} model {model.id} p2a {p2a_nb_correct} / {p2a_nb_total} a2p {a2p_nb_correct} / {a2p_nb_total}"
)
return result, correct
model.test_w_quizzes[:2000], log_prefix="test"
)
- n_test_forward = model.test_w_quizzes[:2000, 0] == self.token_forward
+ n_test_p2a = model.test_w_quizzes[:2000, 0] == self.problem.token_forward
- forward_test_correct = test_correct[n_test_forward]
+ p2a_test_correct = test_correct[n_test_p2a]
- main_test_accuracy = forward_test_correct.sum() / forward_test_correct.size(0)
+ main_test_accuracy = p2a_test_correct.sum() / p2a_test_correct.size(0)
##############################
######################################################################
- def create_w_quizzes(self, model, nb, for_train=True, forward_only=False):
+ def create_w_quizzes(self, model, nb, for_train=True, p2a_only=False):
input = self.generate_token_sequences(nb)
- if not forward_only:
- self.reverse_random_half_in_place(input)
+ if not p2a_only:
+ self.p_a_flip_half_in_place(input)
if for_train:
model.train_w_quizzes = input
######################################################################
- def renew_w_quizzes(self, model, for_train=True, forward_only=False):
+ def renew_w_quizzes(self, model, for_train=True, p2a_only=False):
input = model.train_w_quizzes if for_train else model.test_w_quizzes
if for_train and hasattr(model, "hard_w_quizzes"):
else:
input[...] = self.generate_token_sequences(input.size(0))
- if not forward_only:
- self.reverse_random_half_in_place(input)
+ if not p2a_only:
+ self.p_a_flip_half_in_place(input)
######################################################################
###############################################################
def generate_c_quizzes(
- self, nb, model_for_generation, forward_only=False, generation_temperature=1.0
+ self,
+ nb,
+ model_for_generation,
+ p2a_only=False,
+ temperature_hot=1.0,
+ temperature_cold=1.0,
):
c_quizzes = torch.empty(
nb,
- self.prompt_len + self.answer_len + 2,
+ self.prompt_len + self.answer_len,
device=self.device,
dtype=torch.int64,
)
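+ # one quiz per row: prompt block then answer block, markers embedded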
seq_logproba = torch.zeros(nb, device=self.device)
- if forward_only:
- c_quizzes[:, 0] = self.token_forward
- c_quizzes[:, 1 + self.prompt_len] = self.token_forward
+ if p2a_only:
+ c_quizzes[:, 0] = self.problem.token_forward
+ c_quizzes[:, self.prompt_len] = self.problem.token_forward
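+ # stamp the forward markers, generate the prompt at temperature_hot,
+ # then the answer at temperature_cold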
masked_inplace_autoregression(
model=model_for_generation,
input=c_quizzes,
ar_mask=self.make_ar_mask(c_quizzes, first=True),
seq_logproba=seq_logproba,
- temperature=generation_temperature,
+ temperature=temperature_hot,
deterministic_synthesis=False,
device=self.device,
)
input=c_quizzes,
ar_mask=self.make_ar_mask(c_quizzes),
seq_logproba=seq_logproba,
- temperature=1.0,
+ temperature=temperature_cold,
deterministic_synthesis=False,
device=self.device,
)
else:
- c_quizzes[:, 0] = self.token_backward
- c_quizzes[:, 1 + self.answer_len] = self.token_backward
+ c_quizzes[:, 0] = self.problem.token_backward
+ c_quizzes[:, self.answer_len] = self.problem.token_backward
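+ # a2p: generate the answer at temperature_hot and the prompt at
+ # temperature_cold, then flip and regenerate the answer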
masked_inplace_autoregression(
model=model_for_generation,
input=c_quizzes,
ar_mask=self.make_ar_mask(c_quizzes, first=True),
seq_logproba=seq_logproba,
- temperature=generation_temperature,
+ temperature=temperature_hot,
deterministic_synthesis=False,
device=self.device,
)
input=c_quizzes,
ar_mask=self.make_ar_mask(c_quizzes),
seq_logproba=seq_logproba,
- temperature=0.75,
+ temperature=temperature_cold,
deterministic_synthesis=False,
device=self.device,
)
- c_quizzes = self.reverse_time(c_quizzes)
+ c_quizzes = self.p_a_flip(c_quizzes)
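+ # back in p2a order, regenerate the answer from the generated prompt
+ # as a consistency pass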
masked_inplace_autoregression(
model=model_for_generation,
input=c_quizzes,
ar_mask=self.make_ar_mask(c_quizzes),
seq_logproba=seq_logproba,
- temperature=0.75,
+ temperature=temperature_cold,
deterministic_synthesis=False,
device=self.device,
)