X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=main.py;h=402e6e5db8d98f87dec1efdcfef60e230a8a9ea5;hb=a2346746c9b417eaf97aad87ed31dea92c3bb887;hp=05c3557fe0e8158126506aad654a094645044b0a;hpb=336130cc923761658029a0af9d5862d59405d47a;p=culture.git diff --git a/main.py b/main.py index 05c3557..402e6e5 100755 --- a/main.py +++ b/main.py @@ -12,7 +12,8 @@ from torch import nn from torch.nn import functional as F import ffutils -import mygpt, quizz_machine +import mygpt +import sky, quizz_machine # world quizzes vs. culture quizzes @@ -210,6 +211,7 @@ assert args.nb_train_samples % args.batch_size == 0 assert args.nb_test_samples % args.batch_size == 0 quizz_machine = quizz_machine.QuizzMachine( + sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.physical_batch_size, @@ -390,11 +392,10 @@ def create_c_quizzes( quizz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True) quizz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False) - quizz_machine.save_quizzes( + quizz_machine.problem.save_quizzes( new_c_quizzes[:72], args.result_dir, f"culture_c_quiz_{n_epoch:04d}_{model.id:02d}", - log_string, ) return sum_logits / sum_nb_c_quizzes