From d29c0fc45415c70aa8860aa08fbfff382020382e Mon Sep 17 00:00:00 2001
From: François Fleuret
Date: Sat, 10 Aug 2024 21:41:33 +0200
Subject: [PATCH] Update.

---
 main.py | 33 +++++++++++++++++++++------------
 1 file changed, 21 insertions(+), 12 deletions(-)

diff --git a/main.py b/main.py
index c4dcfb2..0670262 100755
--- a/main.py
+++ b/main.py
@@ -93,19 +93,19 @@ parser.add_argument("--gpus", type=str, default="all")
 
 parser.add_argument("--nb_gpts", type=int, default=5)
 
-parser.add_argument("--max_fail_to_validate", type=int, default=2)
+parser.add_argument("--max_fail_to_validate", type=int, default=3)
 
-parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.98)
+parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95)
 
 parser.add_argument("--proba_understands", type=float, default=0.95)
 
-parser.add_argument("--proba_not_understands", type=float, default=0.5)
+parser.add_argument("--proba_not_understands", type=float, default=0.1)
 
 parser.add_argument("--temperature_hot", type=float, default=1.5)
 
 parser.add_argument("--temperature_cold", type=float, default=1)
 
-parser.add_argument("--prompt_noise", type=float, default=0.0)
+parser.add_argument("--prompt_noise", type=float, default=0.05)
 
 parser.add_argument("--nb_averaging_rounds", type=int, default=3)
 
@@ -122,15 +122,15 @@ grids_tasks = ", ".join(
 parser.add_argument(
     "--grids_world_tasks",
     type=str,
-    default=None,
-    help="A comma-separated subset of: " + grids_tasks + ", or None for all.",
+    default="replace_color,translate,grow,frame",
+    help="A comma-separated subset of: " + grids_tasks + ".",
 )
 
 parser.add_argument(
     "--grids_science_tasks",
     type=str,
     default=None,
-    help="A comma-separated subset of: " + grids_tasks + ", or None for all.",
+    help="A comma-separated subset of: " + grids_tasks + ", or None.",
 )
 
 ######################################################################
@@ -166,8 +166,8 @@ default_args = {
     "model": "37M",
     "batch_size": 25,
     "inference_batch_size": 50,
-    "nb_train_samples": 100000,
-    "nb_test_samples": 10000,
+    "nb_train_samples": 40000,
+    "nb_test_samples": 1000,
 }
 
 for k, v in default_args.items():
@@ -492,11 +492,16 @@ def model_transformer_cold(model):
     # pass
 
 
+warnings.warn("*********** novel procedure!!! **********", RuntimeWarning)
+
 c_quizzes_procedure = [
+    # (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot),
+    # (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold),
+    # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
     (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot),
-    (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold),
-    (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
-    (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_transformer_cold),
+    (("f_B", "f_A", "A", "B"), (0, 1, 1, 0), model_transformer_cold),
+    (("A", "f_A", "B", "f_B"), (0, 0, 1, 1), model_transformer_cold),
+    # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_transformer_cold),
     # (("f_B", "f_A", "A", "B"), (0, 0, 1, 1), model_transformer_cold),
     # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
 ]
@@ -1078,6 +1083,10 @@ if args.test == "quant":
         ),
     )
 
+    print(model)
+    exit(0)
+
+
 ######################################################################
 
 current_epoch = 0
--
2.20.1
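
Note (not part of the patch): below is a minimal, self-contained sketch of one plausible reading of the revised c_quizzes_procedure, assuming each triple is (segment order, 0/1 generation mask, temperature-setting callback), which is inferred from the identifiers in the diff rather than stated by it. The ToyModel class, the TEMPERATURE_* constants, and describe_procedure are hypothetical names introduced here for illustration; they do not appear in main.py.

# Illustrative sketch only -- not code from the patch.
from dataclasses import dataclass


@dataclass
class ToyModel:
    temperature: float = 1.0


TEMPERATURE_HOT = 1.5   # mirrors the --temperature_hot default
TEMPERATURE_COLD = 1.0  # mirrors the --temperature_cold default


def model_transformer_hot(model):
    # Configure the model for high-temperature (exploratory) sampling.
    model.temperature = TEMPERATURE_HOT


def model_transformer_cold(model):
    # Configure the model for low-temperature (conservative) sampling.
    model.temperature = TEMPERATURE_COLD


# Assumed reading of the new procedure: each step lists the order in which the
# four quiz segments are laid out, a 0/1 mask over those segments (1 = the
# segment is generated at this step, 0 = it is kept as-is), and a callback
# that sets the sampling temperature before generation.
c_quizzes_procedure = [
    (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot),
    (("f_B", "f_A", "A", "B"), (0, 1, 1, 0), model_transformer_cold),
    (("A", "f_A", "B", "f_B"), (0, 0, 1, 1), model_transformer_cold),
]


def describe_procedure(procedure):
    # Print, for each step, which segments would be generated and at what
    # temperature, under the interpretation described above.
    model = ToyModel()
    for step, (order, mask, setup) in enumerate(procedure):
        setup(model)
        generated = [name for name, m in zip(order, mask) if m == 1]
        print(f"step {step}: order={order} generate={generated} T={model.temperature}")


if __name__ == "__main__":
    describe_procedure(c_quizzes_procedure)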