Update.
[culture.git] / main.py
diff --git a/main.py b/main.py
index 585cbdf..6b46fa0 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -13,7 +13,7 @@ from torch.nn import functional as F
 
 import ffutils
 import mygpt
-import sky, reasoning, quiz_machine
+import sky, grids, quiz_machine
 
 # world quizzes vs. culture quizzes
 
@@ -57,7 +57,7 @@ parser.add_argument("--nb_train_samples", type=int, default=None)
 
 parser.add_argument("--nb_test_samples", type=int, default=None)
 
-parser.add_argument("--learning_rate", type=float, default=1e-3)
+parser.add_argument("--learning_rate", type=float, default=5e-4)
 
 ########################################
 
@@ -79,7 +79,7 @@ parser.add_argument("--dropout", type=float, default=0.1)
 
 parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
 
-parser.add_argument("--problem", type=str, default="sky")
+parser.add_argument("--problem", type=str, default="grids")
 
 parser.add_argument("--nb_gpts", type=int, default=5)
 
@@ -126,6 +126,7 @@ if args.result_dir is None:
 
 if args.dirty_debug:
     args.accuracy_to_make_c_quizzes = 0.0
+    args.nb_gpts = 2
     nb_new_c_quizzes_for_train = 100
     nb_new_c_quizzes_for_test = 10
 
@@ -250,8 +251,8 @@ if args.problem == "sky":
         speed=args.sky_speed,
     )
     back_accuracy = False
-elif args.problem == "reasoning":
-    problem = reasoning.Reasoning(device=device)
+elif args.problem == "grids":
+    problem = grids.Grids(device=device)
     back_accuracy = True
 else:
     raise ValueError
@@ -417,6 +418,7 @@ def create_c_quizzes(
     )
 
     file_name = os.path.join(args.result_dir, f"culture_c_quiz_{n_epoch:04d}_logp.dat")
+
     with open(file_name, "w") as logp_file:
         while (
             valid_c_quizzes(quizzes_and_nb_correct_records, standard_validity).size(0)
@@ -484,6 +486,8 @@ def create_c_quizzes(
             quizzes_and_nb_correct_records, criteria=lambda nb_correct: nb_correct == n
         )[:72]
 
+        quiz_machine.reverse_random_half_in_place(q)
+
         if q.size(0) > 0:
             quiz_machine.save_quizzes(
                 args.result_dir, f"culture_c_quiz_{n_epoch:04d}_N{n}{s}", q