else:
self.test_c_quizzes.append(new_c_quizzes)
- def comput_correctness(self, c_quizzes, models_for_validation):
- # Create the reverse quizzes
-
+ def reverse_time(self, c_quizzes):
token_forward, token_backward = self.problem.direction_tokens()
l = (c_quizzes.size(1) - 1) // 2
direction = self.problem.token_forward * (
direction == self.problem.token_backward
) + self.problem.token_backward * (direction == self.problem.token_forward)
- reverse_c_quizzes = torch.cat(
- [c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1
- )
+
+ return torch.cat([c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1)
+
+ def comput_correctness(self, c_quizzes, models_for_validation):
+ reversed_c_quizzes = self.reverse_time(c_quizzes)
ar_mask = self.make_ar_mask(c_quizzes)
seq_logproba = torch.empty(ar_mask.size(0), device=self.device)
correct = (c_quizzes == result).long().min(dim=-1).values
- reverse_result = reverse_c_quizzes.clone()
+ reversed_result = reversed_c_quizzes.clone()
masked_inplace_autoregression(
model=model,
batch_size=self.batch_size,
- input=reverse_result,
+ input=reversed_result,
ar_mask=ar_mask,
seq_logproba=seq_logproba,
temperature=1.0,
device=self.device,
)
- reverse_correct = (
- (reverse_c_quizzes == reverse_result).long().min(dim=-1).values
+ reversed_correct = (
+ (reversed_c_quizzes == reversed_result).long().min(dim=-1).values
)
- nb_correct += correct * reverse_correct
+ nb_correct += correct * reversed_correct
return nb_correct
###############################################################
- def generate_quizzes(self, nb, model_for_generation, min_ave_seq_logproba):
+ def generate_quizzes(
+ self, nb, model_for_generation, min_ave_seq_logproba, reverse_cleanup=False
+ ):
c_quizzes = torch.empty(
nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
)
ar_mask_solve = 1 - ar_mask_prompt
seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device)
- warnings.warn("noise injection", RuntimeWarning)
+ # warnings.warn("noise injection", RuntimeWarning)
temperature = 1
- noise_std = torch.rand(1).item()
- self.logger(f"{noise_std=}")
- mygpt.set_noise_injection(model_for_generation, noise_std)
+ # noise_std = torch.rand(1).item()
+ # self.logger(f"{noise_std=}")
+
+ # mygpt.set_noise_injection(model_for_generation, noise_std)
masked_inplace_autoregression(
model=model_for_generation,
device=self.device,
)
+ # mygpt.set_noise_injection(model_for_generation, 0.0)
+
ave_seq_logproba = seq_logproba.mean()
masked_inplace_autoregression(
device=self.device,
)
- mygpt.set_noise_injection(model_for_generation, 0.0)
+ if reverse_cleanup:
+ c_quizzes = self.reverse_time(c_quizzes)
+ masked_inplace_autoregression(
+ model=model_for_generation,
+ batch_size=self.batch_size,
+ input=c_quizzes,
+ ar_mask=ar_mask_solve,
+ seq_logproba=seq_logproba,
+ temperature=temperature,
+ deterministic_synthesis=True,
+ # progress_bar_desc="sampling c_quizzes",
+ device=self.device,
+ )
return c_quizzes, seq_logproba.mean()