Thread a reverse_cleanup flag through c-quizz creation; when set, warn and use a very high sampling temperature (10.0).
[culture.git] / quizz_machine.py
index 7b0b877..6f7492d 100755 (executable)
@@ -386,8 +386,13 @@ class QuizzMachine:
         ar_mask_solve = 1 - ar_mask_prompt
         seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device)
 
+        if reverse_cleanup:
+            warnings.warn("very high temperature with reversed cleanup", RuntimeWarning)
+            temperature = 10.0
+        else:
+            temperature = 1.0
+
         # warnings.warn("noise injection", RuntimeWarning)
-        temperature = 1
         # noise_std = torch.rand(1).item()
         # self.logger(f"{noise_std=}")
 
@@ -445,11 +450,15 @@ class QuizzMachine:
         model_for_generation,
         models_for_validation,
         min_ave_seq_logproba,
+        reverse_cleanup,
         n_epoch,
         result_dir,
     ):
         c_quizzes, ave_seq_logproba = self.generate_quizzes(
-            nb, model_for_generation, min_ave_seq_logproba
+            nb,
+            model_for_generation=model_for_generation,
+            min_ave_seq_logproba=min_ave_seq_logproba,
+            reverse_cleanup=reverse_cleanup,
         )
 
         nb_correct = self.comput_correctness(c_quizzes, models_for_validation)
@@ -465,16 +474,18 @@ class QuizzMachine:
         models,
         mode,
         min_ave_seq_logproba,
+        reverse_cleanup,
         n_epoch,
         result_dir,
     ):
         model_for_generation = Gang(models, nb_models_for_generation, mode)
         models_for_validation = models
         return self.create_c_quizzes(
-            nb,
-            model_for_generation,
-            models_for_validation,
-            min_ave_seq_logproba,
-            n_epoch,
-            result_dir,
+            nb=nb,
+            model_for_generation=model_for_generation,
+            models_for_validation=models_for_validation,
+            min_ave_seq_logproba=min_ave_seq_logproba,
+            reverse_cleanup=reverse_cleanup,
+            n_epoch=n_epoch,
+            result_dir=result_dir,
         )