X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=main.py;h=524715a33bf152a7b364d43d0b3d44c17bfe727a;hb=ff043757ea7d5d992a3d1fc4c435c1422997b1af;hp=05c3557fe0e8158126506aad654a094645044b0a;hpb=336130cc923761658029a0af9d5862d59405d47a;p=culture.git diff --git a/main.py b/main.py index 05c3557..524715a 100755 --- a/main.py +++ b/main.py @@ -12,7 +12,8 @@ from torch import nn from torch.nn import functional as F import ffutils -import mygpt, quizz_machine +import mygpt +import sky, quizz_machine # world quizzes vs. culture quizzes @@ -210,6 +211,7 @@ assert args.nb_train_samples % args.batch_size == 0 assert args.nb_test_samples % args.batch_size == 0 quizz_machine = quizz_machine.QuizzMachine( + sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.physical_batch_size, @@ -390,7 +392,7 @@ def create_c_quizzes( quizz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True) quizz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False) - quizz_machine.save_quizzes( + quizz_machine.problem.save_quizzes( new_c_quizzes[:72], args.result_dir, f"culture_c_quiz_{n_epoch:04d}_{model.id:02d}",