Merge branch 'dev'
[culture.git] / main.py
diff --git a/main.py b/main.py
index 7ba5193..6b00bbf 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -84,11 +84,11 @@ parser.add_argument("--nb_gpts", type=int, default=5)
 
 parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.9)
 
-parser.add_argument("--proba_understands", type=float, default=0.99)
+parser.add_argument("--proba_understands", type=float, default=0.9)
 
 parser.add_argument("--proba_not_understands", type=float, default=0.5)
 
-parser.add_argument("--generation_temperature", type=float, default=2.0)
+parser.add_argument("--generation_temperature", type=float, default=1.0)
 
 parser.add_argument("--dirty_debug", action="store_true", default=False)
 
@@ -373,13 +373,16 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 # This is the key routine that decides what generated quizzes to keep
 
 
-def compute_valid_quizzes(token_logprobas):
+# token_logprobas are NxMxT where M is the number of models
+
+
+def compute_valid_quizzes_(token_logprobas):
     warnings.warn("validation with uniform constraints", RuntimeWarning)
     l = token_logprobas.min(dim=-1).values.sort(dim=-1).values
     return (l[:, 0] < math.log(0.1)) & (l[:, 1] > math.log(0.5))
 
 
-def compute_valid_quizzes_(token_logprobas):
+def compute_valid_quizzes(token_logprobas):
     l = token_logprobas.sum(dim=-1).sort(dim=-1).values
     return (l[:, 0] < math.log(args.proba_not_understands)) & (
         l[:, 1] > math.log(args.proba_understands)
@@ -617,6 +620,10 @@ for n_epoch in range(args.nb_epochs):
         quiz_machine.save_c_quizzes(os.path.join(args.result_dir, filename))
         log_string(f"wrote {filename}")
 
+        # Force one epoch of training
+        for model in models:
+            model.main_test_accuracy = 0.0
+
     ##################################################
     # Select, improve, and eval the worst model