)
else:
flipped_from_forward = torch.cat(
- [
- quizzes[:, 3 * (S + 1) : 3 * (S + 1) + S + 1],
- quizzes[:, 0 * (S + 1) : 2 * (S + 1) + S + 1],
- quizzes[:, 1 * (S + 1) : 1 * (S + 1) + S + 1],
- quizzes[:, 2 * (S + 1) : 0 * (S + 1) + S + 1],
- ],
+ [quizzes[:, 3 * (S + 1) :], quizzes[:, : 3 * (S + 1)]],
dim=1,
)
flipped_from_forward[:, torch.arange(4) * (S + 1)] = self.token_backward
flipped_from_backward = torch.cat(
- [
- quizzes[:, 1 * (S + 1) : 3 * (S + 1) + S + 1],
- quizzes[:, 2 * (S + 1) : 2 * (S + 1) + S + 1],
- quizzes[:, 3 * (S + 1) : 1 * (S + 1) + S + 1],
- quizzes[:, 0 * (S + 1) : 0 * (S + 1) + S + 1],
- ],
- dim=1,
+ [quizzes[:, S + 1 :], quizzes[:, : S + 1]], dim=1
)
flipped_from_backward[:, torch.arange(4) * (S + 1)] = self.token_forward
- m = (flipped[:, 0] == self.token_forward).long()
+ m = (quizzes[:, 0] == self.token_forward).long()[:, None]
flipped = m * flipped_from_forward + (1 - m) * flipped_from_backward
return c_quizzes.to("cpu")
######################################################################
-
- def generate_c_quizzes_fixed_point(
- self,
- nb,
- model_for_generation,
- p2a_only=False,
- temperature_hot=1.0,
- temperature_cold=1.0,
- ):
- c_quizzes = torch.empty(
- nb,
- self.prompt_len + self.answer_len,
- device=self.device,
- dtype=torch.int64,
- )
-
- seq_logproba = torch.zeros(nb, device=self.device)
-
- lt_noisy = lambda s, logits: logits / temperature_hot
- lt_clean = lambda s, logits: logits / temperature_cold
-
- c_quizzes[...] = self.problem.token_backward
-
- masked_inplace_autoregression(
- model=model_for_generation,
- batch_size=self.batch_size,
- input=c_quizzes,
- ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"),
- seq_logproba=seq_logproba,
- logit_transformer=lt_noisy,
- deterministic_synthesis=False,
- device=self.device,
- )
-
- self.save_quiz_illustrations("/tmp", f"c_quizzes_before", c_quizzes)
-
- c_quizzes = self.problem.p_a_flip(c_quizzes)
-
- while True:
- print("ITERATION")
-
- c_quizzes = self.problem.p_a_flip(c_quizzes)
-
- pred = c_quizzes.clone()
-
- masked_inplace_autoregression(
- model=model_for_generation,
- batch_size=self.batch_size,
- input=c_quizzes,
- ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
- seq_logproba=seq_logproba,
- logit_transformer=lt_clean,
- deterministic_synthesis=False,
- device=self.device,
- )
-
- c_quizzes = self.problem.p_a_flip(c_quizzes)
-
- masked_inplace_autoregression(
- model=model_for_generation,
- batch_size=self.batch_size,
- input=c_quizzes,
- ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
- seq_logproba=seq_logproba,
- logit_transformer=lt_clean,
- deterministic_synthesis=False,
- device=self.device,
- )
-
- if pred[202:].equal(c_quizzes[202:]):
- break
-
- self.save_quiz_illustrations("/tmp", f"c_quizzes_after", c_quizzes)
-
- exit(0)
-
- return c_quizzes.to("cpu")
-
- ######################################################################
-
- def generate_c_quizzes_mixing(
- self,
- nb,
- model_for_generation,
- p2a_only=False,
- temperature_hot=1.0,
- temperature_cold=1.0,
- ):
- c_quizzes = torch.empty(
- nb,
- self.prompt_len + self.answer_len,
- device=self.device,
- dtype=torch.int64,
- )
-
- c_quizzes_1 = torch.empty(
- nb,
- self.prompt_len + self.answer_len,
- device=self.device,
- dtype=torch.int64,
- )
-
- c_quizzes_2 = torch.empty(
- nb,
- self.prompt_len + self.answer_len,
- device=self.device,
- dtype=torch.int64,
- )
-
- seq_logproba = torch.zeros(nb, device=self.device)
-
- lt_noisy = lambda s, logits: logits / temperature_hot
- lt_clean = lambda s, logits: logits / temperature_cold
-
- ######################################################################
-
- c_quizzes_1[...] = self.problem.token_backward
- ar_mask = self.problem.make_ar_mask(c_quizzes_1, shape="fwd_012_bck_0")
-
- masked_inplace_autoregression(
- model=model_for_generation,
- batch_size=self.batch_size,
- input=c_quizzes_1,
- ar_mask=ar_mask,
- seq_logproba=seq_logproba,
- logit_transformer=lt_noisy,
- deterministic_synthesis=False,
- device=self.device,
- )
-
- self.save_quiz_illustrations("/tmp", f"c_quizzes_1", c_quizzes_1)
-
- c_quizzes_2[...] = self.problem.token_backward
-
- masked_inplace_autoregression(
- model=model_for_generation,
- batch_size=self.batch_size,
- input=c_quizzes_2,
- ar_mask=ar_mask,
- seq_logproba=seq_logproba,
- logit_transformer=lt_noisy,
- deterministic_synthesis=False,
- device=self.device,
- )
-
- self.save_quiz_illustrations("/tmp", f"c_quizzes_2", c_quizzes_2)
-
- h = len(model_for_generation.trunk) // 2
-
- with torch.autograd.no_grad():
- t = model_for_generation.training
- model_for_generation.eval()
-
- bs1 = model_for_generation.partial_forward(
- mygpt.BracketedSequence(c_quizzes_1), end_layer=h
- )
- bs2 = model_for_generation.partial_forward(
- mygpt.BracketedSequence(c_quizzes_2), end_layer=h
- )
-
- alpha = 0.1
-
- output = model_for_generation.partial_forward(
- mygpt.BracketedSequence(alpha * bs1.x + (1 - alpha) * bs2.x),
- start_layer=h,
- ).x
-
- dist = torch.distributions.categorical.Categorical(logits=output)
- c_quizzes[...] = dist.sample()
-
- c_quizzes[...] = (
- ar_mask * c_quizzes + (1 - ar_mask) * self.problem.token_backward
- )
-
- model_for_generation.train(t)
-
- self.save_quiz_illustrations("/tmp", f"c_quizzes", c_quizzes)
-
- ######################################################################
-
- masked_inplace_autoregression(
- model=model_for_generation,
- batch_size=self.batch_size,
- input=c_quizzes,
- ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
- seq_logproba=seq_logproba,
- logit_transformer=lt_clean,
- deterministic_synthesis=False,
- device=self.device,
- )
-
- self.save_quiz_illustrations("/tmp", f"c_quizzes_A", c_quizzes)
-
- c_quizzes = self.problem.p_a_flip(c_quizzes)
-
- masked_inplace_autoregression(
- model=model_for_generation,
- batch_size=self.batch_size,
- input=c_quizzes,
- ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
- seq_logproba=seq_logproba,
- logit_transformer=lt_clean,
- deterministic_synthesis=False,
- device=self.device,
- )
-
- self.save_quiz_illustrations("/tmp", f"c_quizzes_B", c_quizzes)
-
- print("DONE")
- exit(0)
-
- return c_quizzes.to("cpu")