from torch.nn import functional as F
import ffutils
-import mygpt, quizz_machine
+import mygpt
+import sky, quizz_machine
# world quizzes vs. culture quizzes
assert args.nb_test_samples % args.batch_size == 0
quizz_machine = quizz_machine.QuizzMachine(
+ sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
batch_size=args.physical_batch_size,
quizz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
quizz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
- quizz_machine.save_quizzes(
+ quizz_machine.problem.save_quizzes(
new_c_quizzes[:72],
args.result_dir,
f"culture_c_quiz_{n_epoch:04d}_{model.id:02d}",