X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=e0588224f07ecafde8105b00c2b004e0b195e249;hb=31ed8a54992e7701eebd1c3d49bfe8dc20aa65e3;hp=b57c5121116b438c230a5287c2ca16e5621c6626;hpb=9047bd8185ed99c1302d8812551af3d5bd4602cb;p=culture.git diff --git a/main.py b/main.py index b57c512..e058822 100755 --- a/main.py +++ b/main.py @@ -73,6 +73,8 @@ parser.add_argument("--dropout", type=float, default=0.1) parser.add_argument("--deterministic_synthesis", action="store_true", default=False) +parser.add_argument("--nb_gpts", type=int, default=5) + parser.add_argument("--check", action="store_true", default=False) ###################################################################### @@ -185,9 +187,9 @@ for n in vars(args): ###################################################################### -if args.test: - args.nb_train_samples = 1000 - args.nb_test_samples = 25 +if args.check: + args.nb_train_samples = 500 + args.nb_test_samples = 100 if args.physical_batch_size is None: args.physical_batch_size = args.batch_size @@ -578,7 +580,7 @@ def create_quizzes( task.save_image( new_quizzes[:96], args.result_dir, - f"world_new_{n_epoch:04d}_{model.id:02d}.png", + f"world_quiz_{n_epoch:04d}_{model.id:02d}.png", log_string, ) @@ -587,7 +589,7 @@ def create_quizzes( models = [] -for k in range(5): +for k in range(args.nb_gpts): model = mygpt.MyGPT( vocabulary_size=vocabulary_size, dim_model=args.dim_model, @@ -614,7 +616,7 @@ accuracy_to_make_quizzes = 0.975 nb_new_quizzes_for_train = 1000 nb_new_quizzes_for_test = 100 -if args.test: +if args.check: accuracy_to_make_quizzes = 0.0 nb_new_quizzes_for_train = 10 nb_new_quizzes_for_test = 10