parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
+parser.add_argument("--nb_gpts", type=int, default=5)
+
+parser.add_argument("--check", action="store_true", default=False)
+
######################################################################
args = parser.parse_args()
######################################################################
+if args.check:
+ args.nb_train_samples = 500
+ args.nb_test_samples = 100
if args.physical_batch_size is None:
args.physical_batch_size = args.batch_size
other_models=other_models,
)
+ print(nb_correct)
+
to_keep = new_quizzes[nb_correct == len(other_models) - 1]
log_string(f"keep {to_keep.size(0)} quizzes")
kept.append(to_keep)
task.save_image(
new_quizzes[:96],
args.result_dir,
- f"world_new_{n_epoch:04d}_{model.id:02d}.png",
+ f"world_quiz_{n_epoch:04d}_{model.id:02d}.png",
log_string,
)
models = []
-for k in range(5):
+for k in range(args.nb_gpts):
model = mygpt.MyGPT(
vocabulary_size=vocabulary_size,
dim_model=args.dim_model,
######################################################################
accuracy_to_make_quizzes = 0.975
+nb_new_quizzes_for_train = 1000
+nb_new_quizzes_for_test = 100
+
+if args.check:
+ accuracy_to_make_quizzes = 0.0
+ nb_new_quizzes_for_train = 10
+ nb_new_quizzes_for_test = 10
for n_epoch in range(args.nb_epochs):
# select the model with lowest accuracy
model,
other_models,
task,
- nb_for_train=1000,
- nb_for_test=100,
+ nb_for_train=nb_new_quizzes_for_train,
+ nb_for_test=nb_new_quizzes_for_test,
)