Update.
[culture.git] / main.py
diff --git a/main.py b/main.py
index 05c3557..402e6e5 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -12,7 +12,8 @@ from torch import nn
 from torch.nn import functional as F
 
 import ffutils
 from torch.nn import functional as F
 
 import ffutils
-import mygpt, quizz_machine
+import mygpt
+import sky, quizz_machine
 
 # world quizzes vs. culture quizzes
 
 
 # world quizzes vs. culture quizzes
 
@@ -210,6 +211,7 @@ assert args.nb_train_samples % args.batch_size == 0
 assert args.nb_test_samples % args.batch_size == 0
 
 quizz_machine = quizz_machine.QuizzMachine(
 assert args.nb_test_samples % args.batch_size == 0
 
 quizz_machine = quizz_machine.QuizzMachine(
+    sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2),
     nb_train_samples=args.nb_train_samples,
     nb_test_samples=args.nb_test_samples,
     batch_size=args.physical_batch_size,
     nb_train_samples=args.nb_train_samples,
     nb_test_samples=args.nb_test_samples,
     batch_size=args.physical_batch_size,
@@ -390,11 +392,10 @@ def create_c_quizzes(
     quizz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
     quizz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
 
     quizz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
     quizz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
 
-    quizz_machine.save_quizzes(
+    quizz_machine.problem.save_quizzes(
         new_c_quizzes[:72],
         args.result_dir,
         f"culture_c_quiz_{n_epoch:04d}_{model.id:02d}",
         new_c_quizzes[:72],
         args.result_dir,
         f"culture_c_quiz_{n_epoch:04d}_{model.id:02d}",
-        log_string,
     )
 
     return sum_logits / sum_nb_c_quizzes
     )
 
     return sum_logits / sum_nb_c_quizzes