X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=6b46fa0f23b68ca2c88466be37b60c069758bce6;hb=e0ab20005e2578edff27d4246c6904cf1047ed22;hp=597ec32f71f5232ddbbc7d26514ce1d238112816;hpb=ceddc8cc3adbb045fdef1ccb0b3df2b8fed9eb4c;p=culture.git diff --git a/main.py b/main.py index 597ec32..6b46fa0 100755 --- a/main.py +++ b/main.py @@ -13,7 +13,7 @@ from torch.nn import functional as F import ffutils import mygpt -import sky, reasoning, quiz_machine +import sky, grids, quiz_machine # world quizzes vs. culture quizzes @@ -79,7 +79,7 @@ parser.add_argument("--dropout", type=float, default=0.1) parser.add_argument("--deterministic_synthesis", action="store_true", default=False) -parser.add_argument("--problem", type=str, default="sky") +parser.add_argument("--problem", type=str, default="grids") parser.add_argument("--nb_gpts", type=int, default=5) @@ -126,6 +126,7 @@ if args.result_dir is None: if args.dirty_debug: args.accuracy_to_make_c_quizzes = 0.0 + args.nb_gpts = 2 nb_new_c_quizzes_for_train = 100 nb_new_c_quizzes_for_test = 10 @@ -250,8 +251,8 @@ if args.problem == "sky": speed=args.sky_speed, ) back_accuracy = False -elif args.problem == "reasoning": - problem = reasoning.Reasoning(device=device) +elif args.problem == "grids": + problem = grids.Grids(device=device) back_accuracy = True else: raise ValueError @@ -417,6 +418,7 @@ def create_c_quizzes( ) file_name = os.path.join(args.result_dir, f"culture_c_quiz_{n_epoch:04d}_logp.dat") + with open(file_name, "w") as logp_file: while ( valid_c_quizzes(quizzes_and_nb_correct_records, standard_validity).size(0)