parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
-parser.add_argument("--bidirectional_validation", action="store_true", default=False)
-
parser.add_argument("--problem", type=str, default="sky")
parser.add_argument("--nb_gpts", type=int, default=5)
parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975)
-parser.add_argument("--dirty_debug", action="store_true", default=False)
-
parser.add_argument("--generation_temperature", type=float, default=2.0)
parser.add_argument("--deterministic_validation", action="store_true", default=False)
+parser.add_argument("--bidirectional_validation", action="store_true", default=False)
+
+parser.add_argument("--dirty_debug", action="store_true", default=False)
+
######################################################################
parser.add_argument("--sky_height", type=int, default=6)
args = parser.parse_args()
if args.min_to_validate is None:
- args.min_to_validate = args = nb_gpts - 1
+ args.min_to_validate = args.nb_gpts - 1
if args.max_to_validate is None:
- args.max_to_validate = args = nb_gpts - 1
+ args.max_to_validate = args.nb_gpts - 1
if args.result_dir is None:
args.result_dir = f"results_culture"
class QuizzMachine:
- def make_ar_mask(self, input):
- b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2
+ def make_ar_mask(self, input, first, nb):
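+ # the mask is 1 on positions [first, first + nb) (the span to generate), 0 on the fixed context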
+ i = torch.arange(input.size(1), device=input.device)
+ b = torch.logical_and(i >= first, i < first + nb)
return b.long()[None, :].expand_as(input)
def generate_token_sequences(self, nb):
prompts, answers = self.problem.generate_prompts_and_answers(nb)
+
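+ # record the prompt / answer lengths on first call, and check they never change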
+ if self.prompt_len is None:
+ self.prompt_len = prompts.size(1)
+
+ if self.answer_len is None:
+ self.answer_len = answers.size(1)
+
+ assert prompts.size(1) == self.prompt_len and answers.size(1) == self.answer_len
+
result = []
for prompt, answer in zip(prompts, answers):
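+ # with probability 1/2 store the quiz forward as [fwd, prompt, fwd, answer], otherwise backward as [bwd, answer, bwd, prompt]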
if torch.rand(1) < 0.5:
- a = [torch.tensor([self.token_forward]), prompt, answer]
+ a = [
+ torch.tensor([self.token_forward]),
+ prompt,
+ torch.tensor([self.token_forward]),
+ answer,
+ ]
else:
- a = [torch.tensor([self.token_backward]), answer, prompt]
+ a = [
+ torch.tensor([self.token_backward]),
+ answer,
+ torch.tensor([self.token_backward]),
+ prompt,
+ ]
result.append(torch.cat(a, dim=0)[None, :])
self.batch_size = batch_size
self.device = device
self.logger = logger
+ self.prompt_len = None
+ self.answer_len = None
self.train_w_quizzes = self.generate_token_sequences(nb_train_samples).to(
device
)
def compute_accuracy(input):
input = input[:nmax]
- ar_mask = self.make_ar_mask(input)
+ ar_mask = self.make_ar_mask(input, 2 + self.prompt_len, self.answer_len)
result = input.clone() * (1 - ar_mask)
seq_logproba = torch.empty(input.size(0), device=self.device)
device=self.device,
)
- nb_total, nb_correct = (
- input.size(0),
- (input == result).long().min(dim=1).values.sum(),
- )
+ nb_total = input.size(0)
+ nb_correct = (input == result).long().min(dim=1).values.sum()
return nb_total, nb_correct
##############################
input = self.test_w_quizzes[:96]
- ar_mask = self.make_ar_mask(input)
+ ar_mask = self.make_ar_mask(input, 2 + self.prompt_len, self.answer_len)
result = input.clone() * (1 - ar_mask)
seq_logproba = torch.empty(input.size(0), device=self.device)
else:
self.test_c_quizzes.append(new_c_quizzes)
- def reverse_time(self, c_quizzes):
- l = (c_quizzes.size(1) - 1) // 2
- direction = c_quizzes[:, 0:1]
- direction = self.token_forward * (
- direction == self.token_backward
- ) + self.token_backward * (direction == self.token_forward)
+ def forward_to_backward(self, c_quizzes):
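+ # [fwd, prompt, fwd, answer] -> [bwd, answer, bwd, prompt]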
+ prompts = c_quizzes[:, 1 : 1 + self.prompt_len]
+ answers = c_quizzes[:, 2 + self.prompt_len :]
+ return torch.cat(
+ [
c_quizzes.new_full((c_quizzes.size(0), 1), self.token_backward),
+ answers,
c_quizzes.new_full((c_quizzes.size(0), 1), self.token_backward),
+ prompts,
+ ],
+ dim=1,
+ )
+ def backward_to_forward(self, c_quizzes):
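+ # [bwd, answer, bwd, prompt] -> [fwd, prompt, fwd, answer]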
+ answers = c_quizzes[:, 1 : 1 + self.answer_len]
+ prompts = c_quizzes[:, 2 + self.answer_len :]
return torch.cat(
- [direction, c_quizzes[:, l + 1 :], c_quizzes[:, 1 : l + 1]], dim=1
+ [
+ c_quizzes.new_full((c_quizzes.size(0), 1), self.token_forward),
+ prompts,
+ c_quizzes.new_full((c_quizzes.size(0), 1), self.token_forward),
+ answers,
+ ],
+ dim=1,
)
def compute_correctness(
bidirectional_validation=False,
deterministic_validation=True,
):
- reversed_c_quizzes = self.reverse_time(c_quizzes)
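+ # for bidirectional validation, each quiz must also be solved with prompt and answer swapped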
+ if bidirectional_validation:
+ backward_c_quizzes = self.forward_to_backward(c_quizzes)
- ar_mask = self.make_ar_mask(c_quizzes)
seq_logproba = torch.zeros(
c_quizzes.size(0),
max([m.id for m in models_for_validation]) + 1,
device=self.device,
)
- # Check how many of models can solve the quizzes in both directions
-
nb_correct = 0
for model in models_for_validation:
seq_logproba[...] = 0.0
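+ # mask the answer span, so the validation model regenerates it from the prompt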
+ ar_mask = self.make_ar_mask(result, 2 + self.prompt_len, self.answer_len)
+
masked_inplace_autoregression(
model=model,
batch_size=self.batch_size,
correct = (c_quizzes == result).long().min(dim=-1).values
if bidirectional_validation:
- reversed_result = reversed_c_quizzes.clone()
+ backward_result = backward_c_quizzes.clone()
+
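+ # in the backward layout the prompt comes last, so it is the span to regenerate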
+ ar_mask = self.make_ar_mask(
+ backward_result, 2 + self.answer_len, self.prompt_len
+ )
masked_inplace_autoregression(
model=model,
batch_size=self.batch_size,
- input=reversed_result,
+ input=backward_result,
ar_mask=ar_mask,
seq_logproba=seq_logproba[:, model.id],
temperature=1.0,
deterministic_synthesis=deterministic_validation,
- # progress_bar_desc="solving reversed c_quizzes",
+ # progress_bar_desc="solving backward c_quizzes",
device=self.device,
)
- reversed_correct = (
- (reversed_c_quizzes == reversed_result).long().min(dim=-1).values
+ backward_correct = (
+ (backward_c_quizzes == backward_result).long().min(dim=-1).values
)
- correct *= reversed_correct
+ correct *= backward_correct
# endif
# Then we return the quizz, and re-generate the response, now
# at low temperature
- c_quizzes = self.reverse_time(c_quizzes)
+ c_quizzes = self.backward_to_forward(c_quizzes)
masked_inplace_autoregression(
model=model_for_generation,