self.check_structure(quizzes, struct)
return struct
- def make_ar_mask(self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)):
- assert check_structure(quizzes, struct)
-
- ar_mask = quizzes.new_zeros(quizzes.size())
-
- a = ar_mask.reshape(-1, 4, -1)[:, :, 1:]
- a[:, 0, :] = mask[0]
- a[:, 1, :] = mask[1]
- a[:, 2, :] = mask[2]
- a[:, 3, :] = mask[3]
-
- return ar_mask
-
def reconfigure(self, quizzes, struct=("A", "f_A", "B", "f_B")):
S = self.height * self.width
return result
+ def make_ar_mask(self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)):
+ assert check_structure(quizzes, struct)
+
+ ar_mask = quizzes.new_zeros(quizzes.size())
+
+ a = ar_mask.reshape(-1, 4, -1)[:, :, 1:]
+ a[:, 0, :] = mask[0]
+ a[:, 1, :] = mask[1]
+ a[:, 2, :] = mask[2]
+ a[:, 3, :] = mask[3]
+
+ return ar_mask
+
+ def indices_select(self, quizzes, struct=("A", "f_A", "B", "f_B")):
+ S = self.height * self.width
+ q = quizzes.reshape(-1, 4, S + 1)
+ return (
+ (q[:, 0, 0] == self.l2tok[struct[0]])
+ & (q[:, 1, 0] == self.l2tok[struct[1]])
+ & (q[:, 2, 0] == self.l2tok[struct[2]])
+ & (q[:, 3, 0] == self.l2tok[struct[3]])
+ )
+
def __init__(
self,
max_nb_cached_chunks=None,
######################################################################
- def trivial_prompts_and_answers(self, prompts, answers):
- S = self.height * self.width
- Bs = prompts[:, 2 * (S + 1) + 1 : 2 * (S + 1) + S + 1]
- f_Bs = answers[:, 1:]
- return (Bs == f_Bs).long().min(dim=-1).values > 0
-
def generate_w_quizzes_(self, nb, tasks=None, progress_bar=False):
if tasks is None:
tasks = self.all_tasks
nb = 5
quizzes = grids.generate_w_quizzes_(nb, tasks=[grids.task_fill])
print(grids.get_structure(quizzes))
- blah = grids.reconfigure(quizzes, struct=("A", "B", "f_A", "f_B"))
- print(grids.get_structure(blah))
+ quizzes = grids.reconfigure(quizzes, struct=("A", "B", "f_A", "f_B"))
+ print(grids.get_structure(quizzes))
+
+ i = torch.rand(quizzes.size(0)) < 0.5
+
+ quizzes[i] = grids.reconfigure(quizzes[i], struct=("f_B", "f_A", "B", "A"))
+
+ j = grids.indices_select(quizzes, struct=("f_B", "f_A", "B", "A"))
+
+ print(
+ i.equal(j),
+ grids.get_structure(quizzes[j]),
+ grids.get_structure(quizzes[j == False]),
+ )
+
exit(0)
# nb = 1000
import threading
-######################################################################
-# if output is log(P(X=y)) and target is Y, returns -log P(X=Y) + H(X
-# | X != Y)
-
-
-# output is NxCxT and target is NxT
-def confusion(output, target, reduction="mean"):
- N, C, T = output.shape
- output = output.permute(0, 2, 1).reshape(-1, C)
- target = target.flatten()
- all_t = torch.arange(N * T, device=output.device)
- output = output.log_softmax(dim=-1)
- result = -output[all_t, target]
-
- output[all_t, target] = float("-inf")
- output = output.log_softmax(dim=-1)
- e = output.exp()
- output[all_t, target] = 0
- result = result - (output * e).sum(-1)
-
- if reduction == "none":
- return result.reshape(N, T)
- elif reduction == "mean":
- return result.reshape(N, T).mean()
- elif reduction == "sum":
- return result.reshape(N, T).sum()
- else:
- raise ValueError(f"unknown reduction '{reduction}'.")
-
-
######################################################################
# ar_mask is a tensor with 0s and 1s, of same shape as input, with
class QuizMachine:
- def indices_p2a_and_a2p(self, quizzes):
- i_p2a = quizzes[:, 0] == self.problem.token_forward
- j_p2a = quizzes[:, self.prompt_len] == self.problem.token_forward
- i_a2p = quizzes[:, 0] == self.problem.token_backward
- j_a2p = quizzes[:, self.answer_len] == self.problem.token_backward
- assert ((i_p2a & j_p2a) | (i_a2p & j_a2p)).all()
- return i_p2a, i_a2p
-
- def non_trivial(self, quizzes):
- quizzes = quizzes.clone()
- i_p2a, i_a2p = self.indices_p2a_and_a2p(quizzes)
- quizzes[i_a2p] = self.problem.p_a_flip(quizzes[i_a2p]) # a_fa_b_fb
- return torch.logical_not(
- self.problem.trivial_prompts_and_answers(
- quizzes[:, : self.prompt_len], quizzes[:, self.prompt_len :]
- )
- )
-
- def p_a_flip_half_in_place(self, quizzes):
- i = torch.rand(quizzes.size(0)) < 0.5
- if i.any():
- quizzes[i] = self.problem.p_a_flip(quizzes[i])
-
- def generate_token_sequences(self, nb):
- prompts, answers = self.problem.generate_prompts_and_answers(nb)
-
- if self.prompt_len is None:
- self.prompt_len = prompts.size(1)
-
- if self.answer_len is None:
- self.answer_len = answers.size(1)
-
- assert prompts.size(1) == self.prompt_len and answers.size(1) == self.answer_len
-
- result = []
-
- for prompt, answer in zip(prompts, answers):
- result.append(torch.cat([prompt, answer], dim=0)[None, :])
-
- return torch.cat(result, dim=0)
-
def __init__(
self,
problem,
):
super().__init__()
- self.nb_token_values = problem.nb_token_values()
-
self.problem = problem
self.back_accuracy = back_accuracy
self.batch_size = batch_size
self.train_c_quizzes = []
self.test_c_quizzes = []
- def save_quiz_illustrations(
- self,
- result_dir,
- filename_prefix,
- quizzes,
- mistakes=None,
- show_part_to_predict=True,
- ):
- quizzes = quizzes.clone().to("cpu")
- i_p2a, i_a2p = self.indices_p2a_and_a2p(quizzes)
- p2a = quizzes[i_p2a]
- a2p = quizzes[i_a2p]
- assert p2a.size(0) + a2p.size(0) == quizzes.size(0)
- quizzes[i_a2p] = self.problem.p_a_flip(quizzes[i_a2p])
-
- if show_part_to_predict:
- predicted_prompts = i_a2p.long()
- predicted_answers = 1 - predicted_prompts
- if mistakes is not None:
- # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
- predicted_prompts *= mistakes.to("cpu")
- predicted_answers *= mistakes.to("cpu")
- else:
- # 0/2 ~ not-to-predict / to predict
- predicted_prompts *= 2
- predicted_answers *= 2
- else:
- predicted_prompts = None
- predicted_answers = None
-
- self.problem.save_quiz_illustrations(
- result_dir,
- filename_prefix,
- quizzes[:, : self.prompt_len],
- quizzes[:, self.prompt_len :],
- predicted_prompts,
- predicted_answers,
- )
-
def vocabulary_size(self):
- return self.nb_token_values
+ return self.problem.nb_token_values
######################################################################
def produce_results(
self, n_epoch, model, input, result_dir, deterministic_synthesis
):
- def compute_accuracy(input, log_prefix=None):
- input = input.to(self.device)
- ar_mask = self.problem.make_ar_mask(input, shape="fwd_3_bck_123")
- result = input.clone() * (1 - ar_mask)
- seq_logproba = torch.empty(input.size(0), device=self.device)
+ def predict(input, struct, mask):
+ ar_mask = self.problem.make_ar_mask(
+ quizzes=quizzes, struct=struct, mask=mask
+ )
+ result = quizzes * (1 - ar_mask)
+ seq_logproba = torch.empty(fwd_quizzes, device=self.device)
masked_inplace_autoregression(
model=model,
device=self.device,
)
- correct = torch.empty(input.size(0), dtype=torch.int64, device=input.device)
-
- i_p2a, i_a2p = self.indices_p2a_and_a2p(input)
-
- correct[i_p2a] = (input[i_p2a] == result[i_p2a]).long().min(dim=1).values
-
- if self.back_accuracy and i_a2p.any():
- # accuracy of B->A*->B*=B instead of B->A*=A
- back_input = self.problem.p_a_flip(result[i_a2p])
- back_input[:, 1 + self.prompt_len :] = input[i_a2p, 1 : self.answer_len]
- _, correct[i_a2p] = compute_accuracy(back_input)
-
- if log_prefix is not None:
- p2a_nb_correct = correct[i_p2a].sum()
- p2a_nb_total = correct[i_p2a].size(0)
- a2p_nb_correct = correct[i_a2p].sum()
- a2p_nb_total = correct[i_a2p].size(0)
-
- self.logger(
- f"{log_prefix}_accuracy {n_epoch} model {model.id} p2a {p2a_nb_correct} / {p2a_nb_total} a2p {a2p_nb_correct} / {a2p_nb_total}"
- )
+ nb_correct = (result == quizzes).min(dim=1).long()
return result, correct
- test_result, test_correct = compute_accuracy(input, log_prefix="test")
+ input = input.to(self.device)
+ i = self.problem.indices_select(quizzes=input, struct=struct)
- n_test_p2a = input[:, 0] == self.problem.token_forward
+ test_result_fwd, test_correct_fwd = predict(
+ input[i], ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
+ )
- p2a_test_correct = test_correct[n_test_p2a]
+ input_bck = self.problem.reconfigure(
+ predict(input[i == False], ("f_B", "f_A", "B", "A"), (0, 1, 1, 1))[0],
+ struct=("A", "f_A", "B", "f_B"),
+ )
+
+ l = input_bck.size(1)
+ input_bck[:, 3 * l :] = input[i == False][:, :l]
+ test_result_bck, test_correct_bck = predict(
+ input_bck, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
+ )
- main_test_accuracy = p2a_test_correct.sum() / p2a_test_correct.size(0)
+ main_test_accuracy = test_correct.sum() / test_correct.size(0)
##############################
######################################################################
- def create_w_quizzes(
- self, model, nb_train_samples, nb_test_samples, p2a_only=False
- ):
- model.train_w_quizzes = self.generate_token_sequences(nb_train_samples)
- model.test_w_quizzes = self.generate_token_sequences(nb_test_samples)
+ def flip_half_in_place(self, quizzes):
+ r = torch.randint(quizzes.size(0), device=quizzes.device) < 0.5
+ i = self.problem.indices_select(quizzes=input, struct=("A", "f_A", "B", "f_B"))
+ quizzes[i & r] = self.problem.reconfigure(
+ quizzes[i & r], struct=("f_B", "f_A", "B", "A")
+ )
+ j = self.problem.indices_select(quizzes=input, struct=("f_B", "f_A", "B", "A"))
+ quizzes[j & r] = self.problem.reconfigure(
+ quizzes[j & r], struct=("A", "f_A", "B", "f_B")
+ )
- if not p2a_only:
- self.p_a_flip_half_in_place(model.train_w_quizzes)
- self.p_a_flip_half_in_place(model.test_w_quizzes)
+ def create_w_quizzes(self, model, nb_train_samples, nb_test_samples):
+ model.train_w_quizzes = self.problem.generate_w_quizzes(nb_train_samples)
+ model.test_w_quizzes = self.problem.generate_w_quizzes(nb_test_samples)
+
+ self.flip_half_in_place(model.train_w_quizzes)
+ self.flip_half_in_place(model.test_w_quizzes)
######################################################################
- def renew_train_w_quizzes(self, model, p2a_only=False):
+ def renew_train_w_quizzes(self, model):
if hasattr(model, "hard_w_quizzes"):
self.logger(
f"re-using {model.hard_w_quizzes.size(0)} hard world quizzes from model {model.id}"
model.train_w_quizzes[...] = torch.cat(
[
model.hard_w_quizzes,
- self.generate_token_sequences(
+ self.problem.generate_w_quizzes(
model.train_w_quizzes.size(0) - model.hard_w_quizzes.size(0)
),
],
dim=0,
)
else:
- model.train_w_quizzes[...] = self.generate_token_sequences(
+ model.train_w_quizzes[...] = self.problem.generate_w_quizzes(
model.train_w_quizzes.size(0)
)
- if not p2a_only:
- self.p_a_flip_half_in_place(model.train_w_quizzes)
+ self.flip_half_in_place(model.train_w_quizzes)
######################################################################
# -------------------------------
# f(A), A, f(B) | B
- c_quizzes = self.problem.p_a_flip(c_quizzes, pairwise_flip=True).to(
- self.device
- )
+ c_quizzes = self.problem.flip(c_quizzes, pairwise_flip=True).to(self.device)
result = c_quizzes.clone()
ar_mask = self.problem.make_ar_mask(result, shape="fwd_3_bck_3")
self,
nb,
model_for_generation,
- p2a_only=False,
temperature_hot=1.0,
temperature_cold=1.0,
):
# )
# lt_clean = None
- if p2a_only:
- c_quizzes[...] = self.problem.token_forward
+ c_quizzes[...] = self.problem.token_backward
- masked_inplace_autoregression(
- model=model_for_generation,
- batch_size=self.batch_size,
- input=c_quizzes,
- ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"),
- seq_logproba=seq_logproba,
- logit_transformer=lt_noisy,
- deterministic_synthesis=False,
- device=self.device,
- )
-
- masked_inplace_autoregression(
- model=model_for_generation,
- batch_size=self.batch_size,
- input=c_quizzes,
- ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
- seq_logproba=seq_logproba,
- logit_transformer=lt_clean,
- deterministic_synthesis=False,
- device=self.device,
- )
-
- else:
- c_quizzes[...] = self.problem.token_backward
-
- masked_inplace_autoregression(
- model=model_for_generation,
- batch_size=self.batch_size,
- input=c_quizzes,
- ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"),
- seq_logproba=seq_logproba,
- logit_transformer=lt_noisy,
- deterministic_synthesis=False,
- device=self.device,
- )
+ masked_inplace_autoregression(
+ model=model_for_generation,
+ batch_size=self.batch_size,
+ input=c_quizzes,
+ ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"),
+ seq_logproba=seq_logproba,
+ logit_transformer=lt_noisy,
+ deterministic_synthesis=False,
+ device=self.device,
+ )
- masked_inplace_autoregression(
- model=model_for_generation,
- batch_size=self.batch_size,
- input=c_quizzes,
- ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
- seq_logproba=seq_logproba,
- logit_transformer=lt_clean,
- deterministic_synthesis=False,
- device=self.device,
- )
+ masked_inplace_autoregression(
+ model=model_for_generation,
+ batch_size=self.batch_size,
+ input=c_quizzes,
+ ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
+ seq_logproba=seq_logproba,
+ logit_transformer=lt_clean,
+ deterministic_synthesis=False,
+ device=self.device,
+ )
- c_quizzes = self.problem.p_a_flip(c_quizzes)
+ c_quizzes = self.problem.p_a_flip(c_quizzes)
- masked_inplace_autoregression(
- model=model_for_generation,
- batch_size=self.batch_size,
- input=c_quizzes,
- ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
- seq_logproba=seq_logproba,
- logit_transformer=lt_clean,
- deterministic_synthesis=False,
- device=self.device,
- )
+ masked_inplace_autoregression(
+ model=model_for_generation,
+ batch_size=self.batch_size,
+ input=c_quizzes,
+ ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
+ seq_logproba=seq_logproba,
+ logit_transformer=lt_clean,
+ deterministic_synthesis=False,
+ device=self.device,
+ )
return c_quizzes.to("cpu")