X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=6b00bbfd991178841468b75006f94121668c2b4f;hb=refs%2Fheads%2Fmaster;hp=7ba5193685e150bbf7870f911a7be4115f9f3558;hpb=982438ec146974f415072ff98523503fc8721538;p=culture.git diff --git a/main.py b/main.py index 7ba5193..6b00bbf 100755 --- a/main.py +++ b/main.py @@ -84,11 +84,11 @@ parser.add_argument("--nb_gpts", type=int, default=5) parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.9) -parser.add_argument("--proba_understands", type=float, default=0.99) +parser.add_argument("--proba_understands", type=float, default=0.9) parser.add_argument("--proba_not_understands", type=float, default=0.5) -parser.add_argument("--generation_temperature", type=float, default=2.0) +parser.add_argument("--generation_temperature", type=float, default=1.0) parser.add_argument("--dirty_debug", action="store_true", default=False) @@ -373,13 +373,16 @@ def one_epoch(model, quiz_machine, local_device=main_device): # This is the key routine that decides what generated quizzes to keep -def compute_valid_quizzes(token_logprobas): +# token_logprobas are NxMxT where M is the number of models + + +def compute_valid_quizzes_(token_logprobas): warnings.warn("validation with uniform constraints", RuntimeWarning) l = token_logprobas.min(dim=-1).values.sort(dim=-1).values return (l[:, 0] < math.log(0.1)) & (l[:, 1] > math.log(0.5)) -def compute_valid_quizzes_(token_logprobas): +def compute_valid_quizzes(token_logprobas): l = token_logprobas.sum(dim=-1).sort(dim=-1).values return (l[:, 0] < math.log(args.proba_not_understands)) & ( l[:, 1] > math.log(args.proba_understands) @@ -617,6 +620,10 @@ for n_epoch in range(args.nb_epochs): quiz_machine.save_c_quizzes(os.path.join(args.result_dir, filename)) log_string(f"wrote {filename}") + # Force one epoch of training + for model in models: + model.main_test_accuracy = 0.0 + ################################################## # Select, improve, and eval the worst model