input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
-def masked_inplace_autoregression(
- model,
- batch_size,
- input,
- ar_mask,
- seq_logproba,
- logit_transformer=None,
- deterministic_synthesis=False,
- forbidden_tokens=None,
- logit_biases=None,
- progress_bar_desc=None,
- device=torch.device("cpu"),
-):
- assert input.size() == ar_mask.size()
-
- batches = zip(
- input.split(batch_size),
- ar_mask.split(batch_size),
- seq_logproba.split(batch_size),
- )
-
- if progress_bar_desc is not None:
- batches = tqdm.tqdm(
- batches,
- dynamic_ncols=True,
- desc=progress_bar_desc,
- total=(input.size(0) + batch_size - 1) // batch_size,
- )
-
- with torch.autograd.no_grad():
- t = model.training
- model.eval()
-
- for input, ar_mask, seq_logproba in batches:
- one_batch_masked_inplace_autoregression(
- model=model,
- input=input,
- ar_mask=ar_mask,
- seq_logproba=seq_logproba,
- logit_transformer=logit_transformer,
- deterministic_synthesis=deterministic_synthesis,
- )
-
- model.train(t)
-
-
######################################################################
######################################################################
+ def autoregression(
+ model,
+ input,
+ ar_mask,
+ seq_logproba=None,
+ logit_transformer=None,
+ progress_bar_desc=None,
+ ):
+ assert input.size() == ar_mask.size()
+
+ if seq_logproba is None:
+ seq_logproba = torch.empty(input.size(0), device=self.device)
+
+ batches = zip(
+ input.split(self.batch_size),
+ ar_mask.split(self.batch_size),
+ seq_logproba.split(self.batch_size),
+ )
+
+ if progress_bar_desc is not None:
+ batches = tqdm.tqdm(
+ batches,
+ dynamic_ncols=True,
+ desc=progress_bar_desc,
+ total=(input.size(0) + self.batch_size - 1) // self.batch_size,
+ )
+
+ with torch.autograd.no_grad():
+ t = model.training
+ model.eval()
+
+ for input, ar_mask, seq_logproba in batches:
+ one_batch_masked_inplace_autoregression(
+ model=model,
+ input=input,
+ ar_mask=ar_mask,
+ seq_logproba=seq_logproba,
+ logit_transformer=logit_transformer,
+ deterministic_synthesis=deterministic_synthesis,
+ )
+
+ model.train(t)
+
+ ######################################################################
+
def data_input(self, model, split="train"):
assert split in {"train", "test"}
ar_mask = self.make_ar_mask(quizzes=quizzes, struct=struct, mask=mask)
result = quizzes * (1 - ar_mask)
- seq_logproba = torch.empty(quizzes.size(0), device=self.device)
-
- masked_inplace_autoregression(
+ self.autoregression(
model=model,
- batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
seq_logproba=seq_logproba,
progress_bar_desc="accuracy",
- device=self.device,
)
correct = (result == quizzes).min(dim=1).values.long()
result, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)
)
- masked_inplace_autoregression(
+ self.autoregression(
model=model,
- batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
seq_logproba=seq_logproba[:, model.id],
- device=self.device,
)
correct = (c_quizzes == result).long().min(dim=-1).values
result, ("f_A", "A", "f_B", "B"), mask=(0, 0, 0, 1)
)
- masked_inplace_autoregression(
+ self.autoregression(
model=model,
- batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
seq_logproba=seq_logproba[:, model.id],
- device=self.device,
)
correct *= (reversed_c_quizzes == result).long().min(dim=-1).values
model_for_generation,
temperature_hot=1.0,
temperature_cold=1.0,
+ to_recycle=None,
):
c_quizzes = self.problem.create_empty_quizzes(nb, ("f_B", "f_A", "A", "B"))
c_quizzes = c_quizzes.to(self.device)
lt_noisy = lambda s, logits: logits / temperature_hot
lt_clean = lambda s, logits: logits / temperature_cold
- masked_inplace_autoregression(
+ self.autoregression(
model=model_for_generation,
- batch_size=self.batch_size,
input=c_quizzes,
ar_mask=self.make_ar_mask(
c_quizzes, ("f_B", "f_A", "A", "B"), (1, 0, 0, 0)
),
seq_logproba=seq_logproba,
logit_transformer=lt_noisy,
- device=self.device,
)
- masked_inplace_autoregression(
+ if to_recycle is not None:
+ l = c_quizzes.size(1) // 4
+ self.logger(f"recycling {to_recycle.size(0)} rejected quizzes")
+ c_quizzes[: to_recycle.size(0), :l] = to_recycle[:, 3 * l :]
+
+ self.autoregression(
model=model_for_generation,
- batch_size=self.batch_size,
input=c_quizzes,
ar_mask=self.make_ar_mask(
c_quizzes, ("f_B", "f_A", "A", "B"), (0, 1, 1, 1)
),
seq_logproba=seq_logproba,
logit_transformer=lt_clean,
- device=self.device,
)
c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B"))
- masked_inplace_autoregression(
+ self.autoregression(
model=model_for_generation,
- batch_size=self.batch_size,
input=c_quizzes,
ar_mask=self.make_ar_mask(
c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
),
seq_logproba=seq_logproba,
logit_transformer=lt_clean,
- device=self.device,
)
return c_quizzes.to("cpu")
lt_noisy = lambda s, logits: logits / temperature_hot
- masked_inplace_autoregression(
+ self.autoregression(
model=model_for_generation,
- batch_size=self.batch_size,
input=c_quizzes,
ar_mask=self.make_ar_mask(
c_quizzes, ("A", "f_A", "B", "f_B"), (1, 1, 1, 1)
),
seq_logproba=seq_logproba,
logit_transformer=lt_noisy,
- device=self.device,
+ )
+
+ return c_quizzes.to("cpu")
+
+ ######################################################################
+
+ def generate_c_quizzes_2(
+ self,
+ nb,
+ model_for_generation,
+ temperature_hot=1.0,
+ temperature_cold=1.0,
+ ):
+ warnings.warn(
+ "**************************** simple quiz generation", RuntimeWarning
+ )
+
+ seq_logproba = torch.zeros(nb, device=self.device)
+
+ lt_noisy = lambda s, logits: logits / temperature_hot
+ lt_clean = lambda s, logits: logits / temperature_cold
+
+ c_quizzes = self.problem.create_empty_quizzes(nb, ("A", "f_A", "B", "f_B"))
+ c_quizzes = c_quizzes.to(self.device)
+
+ self.autoregression(
+ model=model_for_generation,
+ input=c_quizzes,
+ ar_mask=self.make_ar_mask(
+ c_quizzes, ("A", "f_A", "B", "f_B"), (1, 1, 0, 0)
+ ),
+ seq_logproba=seq_logproba,
+ logit_transformer=lt_noisy,
+ )
+
+ c_quizzes2 = self.problem.create_empty_quizzes(nb, ("B", "f_B", "A", "f_A"))
+ c_quizzes2 = c_quizzes2.to(self.device)
+
+ self.autoregression(
+ model=model_for_generation,
+ input=c_quizzes2,
+ ar_mask=self.make_ar_mask(
+ c_quizzes2,
+ ("B", "f_B", "A", "f_A"),
+ (1, 0, 0, 0),
+ ),
+ seq_logproba=seq_logproba,
+ logit_transformer=lt_noisy,
+ )
+
+ l = c_quizzes.size(1) // 4
+ c_quizzes[:, 2 * l : 3 * l] = c_quizzes2[:, :l]
+
+ self.autoregression(
+ model=model_for_generation,
+ input=c_quizzes,
+ ar_mask=self.make_ar_mask(
+ c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
+ ),
+ seq_logproba=seq_logproba,
+ logit_transformer=lt_clean,
)
return c_quizzes.to("cpu")