for n in range(nb_correct.max() + 1):
recorded[n].append(new_c_quizzes[nb_correct == n].clone())
- nv = [recorded[n][-1].size(0) for n in recorded.keys()]
+ nv = F.one_hot(nb_correct, num_classes=len(models) + 1).sum(0)
+ nv = " ".join([str(x.item()) for x in nv])
log_string(f"keep c_quizzes kept {nv} total {nb_validated()} / {nb_to_create}")
ar_mask_solve = 1 - ar_mask_prompt
seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device)
- warnings.warn("very high temperature with reversed cleanup", RuntimeWarning)
- temperature = 10
+ if reverse_cleanup:
+ warnings.warn("very high temperature with reversed cleanup", RuntimeWarning)
+ temperature = 10.0
+ else:
+ temperature = 1.0
# warnings.warn("noise injection", RuntimeWarning)
# noise_std = torch.rand(1).item()