From: François Fleuret
Date: Wed, 25 Sep 2024 06:11:55 +0000 (+0200)
Subject: Update.
X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=5c6b9083c598cdd665a09a8e774d8b288a951dc2;p=culture.git

Update.
---

diff --git a/main.py b/main.py
index de04d5b..87e73ad 100755
--- a/main.py
+++ b/main.py
@@ -45,23 +45,6 @@ parser.add_argument("--train_batch_size", type=int, default=None)
 
 parser.add_argument("--eval_batch_size", type=int, default=25)
 
-parser.add_argument("--nb_train_samples", type=int, default=50000)
-
-
-parser.add_argument("--nb_test_samples", type=int, default=10000)
-
-parser.add_argument("--nb_c_quizzes", type=int, default=5000)
-
-parser.add_argument("--c_quiz_multiplier", type=int, default=1)
-
-parser.add_argument("--learning_rate", type=float, default=5e-4)
-
-parser.add_argument("--nb_have_to_be_correct", type=int, default=3)
-
-parser.add_argument("--nb_have_to_be_wrong", type=int, default=1)
-
-parser.add_argument("--nb_mistakes_to_be_wrong", type=int, default=5)
-
 # ----------------------------------
 
 parser.add_argument("--model_type", type=str, default="standard")
@@ -80,6 +63,8 @@ parser.add_argument("--nb_blocks", type=int, default=None)
 
 parser.add_argument("--dropout", type=float, default=0.5)
 
+parser.add_argument("--learning_rate", type=float, default=5e-4)
+
 # ----------------------------------
 
 parser.add_argument("--nb_threads", type=int, default=1)
@@ -88,21 +73,35 @@ parser.add_argument("--gpus", type=str, default="all")
 
 # ----------------------------------
 
-parser.add_argument("--nb_models", type=int, default=5)
-
-parser.add_argument("--proba_plasticity", type=float, default=0.25)
+parser.add_argument("--c_quiz_multiplier", type=int, default=1)
 
 parser.add_argument("--diffusion_nb_iterations", type=int, default=25)
 
 parser.add_argument("--diffusion_proba_corruption", type=float, default=0.05)
 
+parser.add_argument("--nb_train_samples", type=int, default=50000)
+
+parser.add_argument("--nb_test_samples", type=int, default=10000)
+
+parser.add_argument("--nb_c_quizzes", type=int, default=5000)
+
+# ----------------------------------
+
+parser.add_argument("--nb_models", type=int, default=5)
+
+parser.add_argument("--nb_have_to_be_correct", type=int, default=3)
+
+parser.add_argument("--nb_have_to_be_wrong", type=int, default=1)
+
+parser.add_argument("--nb_mistakes_to_be_wrong", type=int, default=5)
+
+parser.add_argument("--proba_plasticity", type=float, default=0.0)
+
 parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95)
 
 parser.add_argument("--proba_prompt_noise", type=float, default=0.05)
 
-parser.add_argument("--proba_hint", type=float, default=0.25)
-
-parser.add_argument("--quizzes", type=str, default=None)
+parser.add_argument("--proba_hints", type=float, default=0.1)
 
 ######################################################################
@@ -318,7 +317,7 @@ def generate_quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1):
 
 def add_hints_imt(imt_set):
     """Set every component of the mask to zero with probability
-    args.proba_hint, and for each component set to zero, copy the
+    args.proba_hints, and for each component set to zero, copy the
     corresponding value from the target into the input
 
     """
@@ -327,8 +326,9 @@ def add_hints_imt(imt_set):
     # t = h.sort(dim=1).values[:, args.nb_hints, None]
     # mask_hints = (h < t).long()
     mask_hints = (
-        torch.rand(input.size(), device=input.device) < args.proba_hint
+        torch.rand(input.size(), device=input.device) < args.proba_hints
     ).long() * masks
+    masks = (1 - mask_hints) * masks
     input = (1 - mask_hints) * input + mask_hints * targets
 
     return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
@@ -610,8 +610,8 @@ def one_complete_epoch(
     quizzes = generate_quiz_set(
         args.nb_test_samples,
-        # c_quizzes=None,
-        c_quizzes=test_c_quizzes,
+        c_quizzes=None,
+        # c_quizzes=test_c_quizzes,
         c_quiz_multiplier=args.c_quiz_multiplier,
     )
 
     imt_set = samples_for_prediction_imt(quizzes.to(local_device))
@@ -966,9 +966,9 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 ###
 ### test_c_quizzes = train_c_quizzes[nb_correct >= len(models)//2]
 ###
-### for model in models:
-###     inject_plasticity(model, args.proba_plasticity)
-###     model.test_accuracy = 0
+# for model in models:
+#     inject_plasticity(model, args.proba_plasticity)
+#     model.test_accuracy = 0
 
 #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
 for n_epoch in range(current_epoch, args.nb_epochs):
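
The hint-injection step that this commit touches (renaming --proba_hint to --proba_hints and re-zeroing the mask on hinted components) can be sketched standalone as follows. This is a minimal sketch, assuming input, masks, and targets are (N, L) integer tensors where masks is 1 on the components to be predicted; the tensor sizes, the relation input = targets * (1 - masks), and the probability values are illustrative, not taken from main.py.

# Minimal sketch of the hint-injection logic of add_hints_imt,
# under the assumptions stated above.
import torch

proba_hints = 0.1  # mirrors the new default of --proba_hints

N, L = 4, 16
targets = torch.randint(10, (N, L))
masks = (torch.rand(N, L) < 0.5).long()  # 1 = component to be predicted
input = targets * (1 - masks)            # assumed: masked components zeroed

# Reveal a random subset of the masked components as hints
mask_hints = (torch.rand(N, L) < proba_hints).long() * masks

# Hinted components no longer have to be predicted...
masks = (1 - mask_hints) * masks

# ...and their target values are copied into the input
input = (1 - mask_hints) * input + mask_hints * targets

imt = torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
print(imt.size())  # torch.Size([4, 3, 16])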