Update.
[culture.git] / main.py
diff --git a/main.py b/main.py
index 4ff50d7..9c3d7f1 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -12,8 +12,10 @@ from torch import nn
 from torch.nn import functional as F
 
 import ffutils
+
 import mygpt
 import sky, grids, quiz_machine
+from problem import MultiThreadProblem
 
 # world quizzes vs. culture quizzes
 
@@ -76,6 +78,8 @@ parser.add_argument("--deterministic_synthesis", action="store_true", default=Fa
 
 parser.add_argument("--problem", type=str, default="grids")
 
+parser.add_argument("--multi_thread_problem", action="store_true", default=False)
+
 parser.add_argument("--nb_gpts", type=int, default=5)
 
 parser.add_argument("--min_to_validate", type=int, default=None)
@@ -244,6 +248,9 @@ elif args.problem == "grids":
 else:
     raise ValueError
 
+if args.multi_thread_problem:
+    problem = MultiThreadProblem(problem, args.nb_train_samples, chunk_size=1000)
+
 quiz_machine = quiz_machine.QuizMachine(
     problem=problem,
     nb_train_samples=args.nb_train_samples,