X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=305bd3cfb75a351d1b8513dcf9ce136c7a55844f;hb=2192d72289bbf2cd069f67d3e93daf7934f886af;hp=c763016bb7e06be9252dda4996b2f74c053732c8;hpb=bf48dc69f7f57ad391481c8917570e35f661cc4a;p=picoclvr.git diff --git a/main.py b/main.py index c763016..305bd3c 100755 --- a/main.py +++ b/main.py @@ -133,6 +133,11 @@ parser.add_argument("--expr_result_max", type=int, default=99) parser.add_argument("--expr_input_file", type=str, default=None) +############################## +# World options + +parser.add_argument("--world_vqae_nb_epochs", type=int, default=25) + ###################################################################### args = parser.parse_args() @@ -182,9 +187,9 @@ default_args = { "nb_test_samples": 10000, }, "world": { - "nb_epochs": 5, + "nb_epochs": 10, "batch_size": 25, - "nb_train_samples": 10000, + "nb_train_samples": 125000, "nb_test_samples": 1000, }, } @@ -328,6 +333,8 @@ elif args.task == "world": nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.batch_size, + vqae_nb_epochs=args.world_vqae_nb_epochs, + logger=log_string, device=device, )