X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=6b46fa0f23b68ca2c88466be37b60c069758bce6;hb=e0ab20005e2578edff27d4246c6904cf1047ed22;hp=02e1a8d6698a58d84c92001b8a06bfba2a461d47;hpb=46645637edb8a39ed6a674696f9d78cc4603b805;p=culture.git diff --git a/main.py b/main.py index 02e1a8d..6b46fa0 100755 --- a/main.py +++ b/main.py @@ -13,7 +13,7 @@ from torch.nn import functional as F import ffutils import mygpt -import sky, reasoning, quiz_machine +import sky, grids, quiz_machine # world quizzes vs. culture quizzes @@ -79,7 +79,7 @@ parser.add_argument("--dropout", type=float, default=0.1) parser.add_argument("--deterministic_synthesis", action="store_true", default=False) -parser.add_argument("--problem", type=str, default="sky") +parser.add_argument("--problem", type=str, default="grids") parser.add_argument("--nb_gpts", type=int, default=5) @@ -251,8 +251,8 @@ if args.problem == "sky": speed=args.sky_speed, ) back_accuracy = False -elif args.problem == "reasoning": - problem = reasoning.Reasoning(device=device) +elif args.problem == "grids": + problem = grids.Grids(device=device) back_accuracy = True else: raise ValueError @@ -418,6 +418,7 @@ def create_c_quizzes( ) file_name = os.path.join(args.result_dir, f"culture_c_quiz_{n_epoch:04d}_logp.dat") + with open(file_name, "w") as logp_file: while ( valid_c_quizzes(quizzes_and_nb_correct_records, standard_validity).size(0)