device=torch.device("cpu"),
):
assert input.size() == ar_mask.size()
device=torch.device("cpu"),
):
assert input.size() == ar_mask.size()
batch_size=self.batch_size,
input=c_quizzes,
ar_mask=ar_mask,
seq_logproba=seq_logproba,
temperature=temperature,
deterministic_synthesis=False,
batch_size=self.batch_size,
input=c_quizzes,
ar_mask=ar_mask,
seq_logproba=seq_logproba,
temperature=temperature,
deterministic_synthesis=False,
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
seq_logproba=seq_logproba,
temperature=1.0,
deterministic_synthesis=True,
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
seq_logproba=seq_logproba,
temperature=1.0,
deterministic_synthesis=True,
batch_size=self.batch_size,
input=reverse_result,
ar_mask=ar_mask,
seq_logproba=seq_logproba,
temperature=1.0,
deterministic_synthesis=True,
batch_size=self.batch_size,
input=reverse_result,
ar_mask=ar_mask,
seq_logproba=seq_logproba,
temperature=1.0,
deterministic_synthesis=True,