From 042f645a83bf7e4513315d583e923a382ecc6011 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 11 Aug 2024 22:36:22 +0200 Subject: [PATCH] Update. --- main.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 40772c2..cd6e3a9 100755 --- a/main.py +++ b/main.py @@ -78,7 +78,7 @@ parser.add_argument("--nb_heads", type=int, default=None) parser.add_argument("--nb_blocks", type=int, default=None) -parser.add_argument("--dropout", type=float, default=0.1) +parser.add_argument("--dropout", type=float, default=0.5) # ---------------------------------- parser.add_argument("--deterministic_synthesis", action="store_true", default=False) @@ -93,13 +93,15 @@ parser.add_argument("--gpus", type=str, default="all") parser.add_argument("--nb_gpts", type=int, default=5) +parser.add_argument("--min_succeed_to_validate", type=int, default=2) + parser.add_argument("--max_fail_to_validate", type=int, default=3) parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95) parser.add_argument("--proba_understands", type=float, default=0.95) -parser.add_argument("--proba_not_understands", type=float, default=0.1) +parser.add_argument("--proba_not_understands", type=float, default=0.5) parser.add_argument("--temperature_hot", type=float, default=1.5) @@ -663,7 +665,8 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 nb_fail = (probas <= args.proba_not_understands).long().sum(dim=1) to_keep = ( - (nb_succeed + nb_fail == probas.size(1)) + # (nb_succeed + nb_fail == probas.size(1)) + (nb_succeed >= args.min_succeed_to_validate) & (nb_fail >= 1) & (nb_fail <= args.max_fail_to_validate) ) -- 2.39.5