######################################################################
-
-class Gang(nn.Module):
- def __init__(self, models, nb_models_for_generation, mode="groupthink"):
- super().__init__()
- self.models = nn.ModuleList(models)
- self.nb_models_for_generation = nb_models_for_generation
- self.mode = mode
-
- def forward(self, bs):
- # If first = 0, we are re-starting an auto-regressive process,
- # that's the right moment to randomize who gonna do it
- if bs.first == 0:
- self.models_to_use = [
- self.models[k]
- for k in torch.randperm(len(self.models))[
- : self.nb_models_for_generation
- ]
- ]
-
- all_the_logits = torch.cat(
- [model(bs).x[None] for model in self.models_to_use], dim=0
- )
-
- if self.mode == "groupthink":
- y = all_the_logits.mean(dim=0)
- elif self.mode == "groupwork":
- m = torch.rand(all_the_logits.size(), device=all_the_logits.device)
- m = (m.sort(dim=0).indices == 0).long()
- y = (y * m).sum(dim=0)
- else:
- raise ValueError(f"Invalid mode {self.mode}")
-
- return BracketedSequence(y, bs.first, bs.nb)
-
-
-######################################################################
-
# ar_mask is a tensor with 0s and 1s, of same shape as input, with
# 1s where tokens should be generated. The others are kept
# unchanged.
###############################################################
- def generate_quizzes(
- self, nb, model_for_generation, min_ave_seq_logproba, reverse_cleanup=False
- ):
+ def generate_quizzes(self, nb, model_for_generation, reverse_cleanup=False):
c_quizzes = torch.empty(
nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
)
ar_mask_solve = 1 - ar_mask_prompt
seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device)
+ if reverse_cleanup:
+ warnings.warn("very high temperature with reversed cleanup", RuntimeWarning)
+ temperature = 10.0
+ else:
+ temperature = 1.0
+
# warnings.warn("noise injection", RuntimeWarning)
- temperature = 1
# noise_std = torch.rand(1).item()
# self.logger(f"{noise_std=}")
seq_logproba=seq_logproba,
temperature=temperature,
deterministic_synthesis=False,
- # progress_bar_desc="sampling c_quizzes",
device=self.device,
)
seq_logproba=seq_logproba,
temperature=temperature,
deterministic_synthesis=True,
- # progress_bar_desc="sampling c_quizzes",
device=self.device,
)
seq_logproba=seq_logproba,
temperature=temperature,
deterministic_synthesis=True,
- # progress_bar_desc="sampling c_quizzes",
device=self.device,
)
return c_quizzes, seq_logproba.mean()
-
- ######################################################################
-
- def create_c_quizzes(
- self,
- nb,
- model_for_generation,
- models_for_validation,
- min_ave_seq_logproba,
- reverse_cleanup,
- n_epoch,
- result_dir,
- ):
- c_quizzes, ave_seq_logproba = self.generate_quizzes(
- nb,
- model_for_generation=model_for_generation,
- min_ave_seq_logproba=min_ave_seq_logproba,
- reverse_cleanup=reverse_cleanup,
- )
-
- nb_correct = self.comput_correctness(c_quizzes, models_for_validation)
-
- return c_quizzes, nb_correct, ave_seq_logproba
-
- ######################################################################
-
- def gang_create_c_quizzes(
- self,
- nb,
- nb_models_for_generation,
- models,
- mode,
- min_ave_seq_logproba,
- n_epoch,
- result_dir,
- ):
- model_for_generation = Gang(models, nb_models_for_generation, mode)
- models_for_validation = models
- return self.create_c_quizzes(
- nb,
- model_for_generation,
- models_for_validation,
- min_ave_seq_logproba,
- n_epoch,
- result_dir,
- )