X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=main.py;h=be0d8e0a0871b2283f7bbecc5307b9a5a494262a;hb=90f27333118e72b35068d7d7ac29e7b14f27aa3b;hp=918f75d8b1e0ff4ae3f04f3310026ab82afd81ea;hpb=64abc9f3a07a8211f308271fde7d8f876a968ab5;p=culture.git diff --git a/main.py b/main.py index 918f75d..be0d8e0 100755 --- a/main.py +++ b/main.py @@ -13,7 +13,7 @@ from torch.nn import functional as F import ffutils import mygpt -import sky, wireworld, quizz_machine +import sky, reasoning, quizz_machine # world quizzes vs. culture quizzes @@ -79,23 +79,23 @@ parser.add_argument("--dropout", type=float, default=0.1) parser.add_argument("--deterministic_synthesis", action="store_true", default=False) -parser.add_argument("--both_directions", action="store_true", default=False) - parser.add_argument("--problem", type=str, default="sky") parser.add_argument("--nb_gpts", type=int, default=5) -parser.add_argument("--min_to_validate", type=int, default=4) +parser.add_argument("--min_to_validate", type=int, default=None) -parser.add_argument("--max_to_validate", type=int, default=4) +parser.add_argument("--max_to_validate", type=int, default=None) parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975) -parser.add_argument("--dirty_debug", action="store_true", default=False) +parser.add_argument("--generation_temperature", type=float, default=2.0) + +parser.add_argument("--deterministic_validation", action="store_true", default=False) -parser.add_argument("--generation_temperature", type=float, default=1.0) +parser.add_argument("--bidirectional_validation", action="store_true", default=False) -parser.add_argument("--stochastic_validation", action="store_true", default=False) +parser.add_argument("--dirty_debug", action="store_true", default=False) ###################################################################### @@ -113,6 +113,12 @@ parser.add_argument("--sky_speed", type=int, default=3) args = parser.parse_args() +if args.min_to_validate is None: + args.min_to_validate = args.nb_gpts - 1 + +if args.max_to_validate is None: + args.max_to_validate = args.nb_gpts - 1 + if args.result_dir is None: args.result_dir = f"results_culture" @@ -243,8 +249,10 @@ if args.problem == "sky": nb_iterations=args.sky_nb_iterations, speed=args.sky_speed, ) -elif args.problem == "wireworld": - problem = wireworld.Wireworld(height=8, width=10, nb_iterations=2, speed=5) + back_accuracy = False +elif args.problem == "reasoning": + problem = reasoning.Reasoning(device=device) + back_accuracy = True else: raise ValueError @@ -252,6 +260,7 @@ quizz_machine = quizz_machine.QuizzMachine( problem=problem, nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, + back_accuracy=back_accuracy, batch_size=args.physical_batch_size, result_dir=args.result_dir, logger=log_string, @@ -423,8 +432,8 @@ def create_c_quizzes( nb_correct, seq_logproba = quizz_machine.compute_correctness( c_quizzes, models, - both_directions=args.both_directions, - deterministic_validation=not args.stochastic_validation, + bidirectional_validation=args.bidirectional_validation, + deterministic_validation=args.deterministic_validation, ) for n, l in zip(nb_correct, seq_logproba):