Update.
[culture.git] / main.py
diff --git a/main.py b/main.py
index d7fb3d1..5537965 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -13,7 +13,7 @@ from torch.nn import functional as F
 
 import ffutils
 import mygpt
 
 import ffutils
 import mygpt
-import sky, quizz_machine
+import sky, wireworld, quizz_machine
 
 # world quizzes vs. culture quizzes
 
 
 # world quizzes vs. culture quizzes
 
@@ -37,7 +37,7 @@ parser = argparse.ArgumentParser(
     formatter_class=argparse.ArgumentDefaultsHelpFormatter,
 )
 
     formatter_class=argparse.ArgumentDefaultsHelpFormatter,
 )
 
-parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
+parser.add_argument("--log_filename", type=str, default="train.log")
 
 parser.add_argument("--result_dir", type=str, default=None)
 
 
 parser.add_argument("--result_dir", type=str, default=None)
 
@@ -79,6 +79,8 @@ parser.add_argument("--dropout", type=float, default=0.1)
 
 parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
 
 
 parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
 
+parser.add_argument("--problem", type=str, default="sky")
+
 parser.add_argument("--nb_gpts", type=int, default=5)
 
 parser.add_argument("--nb_models_for_generation", type=int, default=1)
 parser.add_argument("--nb_gpts", type=int, default=5)
 
 parser.add_argument("--nb_models_for_generation", type=int, default=1)
@@ -219,8 +221,15 @@ else:
 assert args.nb_train_samples % args.batch_size == 0
 assert args.nb_test_samples % args.batch_size == 0
 
 assert args.nb_train_samples % args.batch_size == 0
 assert args.nb_test_samples % args.batch_size == 0
 
+if args.problem == "sky":
+    problem = (sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2, speed=2),)
+elif args.problem == "wireworld":
+    problem = wireworld.Wireworld(height=10, width=15, nb_iterations=4)
+else:
+    raise ValueError
+
 quizz_machine = quizz_machine.QuizzMachine(
 quizz_machine = quizz_machine.QuizzMachine(
-    problem=sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2, speed=2),
+    problem=problem,
     nb_train_samples=args.nb_train_samples,
     nb_test_samples=args.nb_test_samples,
     batch_size=args.physical_batch_size,
     nb_train_samples=args.nb_train_samples,
     nb_test_samples=args.nb_test_samples,
     batch_size=args.physical_batch_size,
@@ -343,7 +352,6 @@ def run_tests(model, quizz_machine, deterministic_synthesis):
             n_epoch=n_epoch,
             model=model,
             result_dir=args.result_dir,
             n_epoch=n_epoch,
             model=model,
             result_dir=args.result_dir,
-            logger=log_string,
             deterministic_synthesis=deterministic_synthesis,
         )
 
             deterministic_synthesis=deterministic_synthesis,
         )
 
@@ -397,7 +405,6 @@ def create_c_quizzes(
             min_ave_seq_logproba=min_ave_seq_logproba,
             n_epoch=n_epoch,
             result_dir=args.result_dir,
             min_ave_seq_logproba=min_ave_seq_logproba,
             n_epoch=n_epoch,
             result_dir=args.result_dir,
-            logger=log_string,
         )
 
         sum_logits += new_c_quizzes.size(0) * ave_seq_logproba
         )
 
         sum_logits += new_c_quizzes.size(0) * ave_seq_logproba
@@ -487,7 +494,8 @@ for n_epoch in range(args.nb_epochs):
 
     a = [(model.id, float(model.main_test_accuracy)) for model in models]
     a.sort(key=lambda p: p[0])
 
     a = [(model.id, float(model.main_test_accuracy)) for model in models]
     a.sort(key=lambda p: p[0])
-    log_string(f"current accuracies {a}")
+    s = " ".join([f"{p[1]*100:.02f}%" for p in a])
+    log_string(f"current accuracies {s}")
 
     # select the model with lowest accuracy
     models.sort(key=lambda model: model.main_test_accuracy)
 
     # select the model with lowest accuracy
     models.sort(key=lambda model: model.main_test_accuracy)