Update.
[culture.git] / main.py
diff --git a/main.py b/main.py
index 50e5611..6b46fa0 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -13,7 +13,7 @@ from torch.nn import functional as F
 
 import ffutils
 import mygpt
-import sky, reasoning, quiz_machine
+import sky, grids, quiz_machine
 
 # world quizzes vs. culture quizzes
 
@@ -79,7 +79,7 @@ parser.add_argument("--dropout", type=float, default=0.1)
 
 parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
 
-parser.add_argument("--problem", type=str, default="sky")
+parser.add_argument("--problem", type=str, default="grids")
 
 parser.add_argument("--nb_gpts", type=int, default=5)
 
@@ -126,9 +126,9 @@ if args.result_dir is None:
 
 if args.dirty_debug:
     args.accuracy_to_make_c_quizzes = 0.0
+    args.nb_gpts = 2
     nb_new_c_quizzes_for_train = 100
     nb_new_c_quizzes_for_test = 10
-    args.nb_gpts = 2
 
 ######################################################################
 
@@ -251,8 +251,8 @@ if args.problem == "sky":
         speed=args.sky_speed,
     )
     back_accuracy = False
-elif args.problem == "reasoning":
-    problem = reasoning.Reasoning(device=device)
+elif args.problem == "grids":
+    problem = grids.Grids(device=device)
     back_accuracy = True
 else:
     raise ValueError
@@ -418,6 +418,7 @@ def create_c_quizzes(
     )
 
     file_name = os.path.join(args.result_dir, f"culture_c_quiz_{n_epoch:04d}_logp.dat")
+
     with open(file_name, "w") as logp_file:
         while (
             valid_c_quizzes(quizzes_and_nb_correct_records, standard_validity).size(0)