X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=58e80462609bcda7cb8bc43f88d5dfec9022b8d4;hb=8e23dd068df00df61c690ffa89ecc8cb9db4b32d;hp=c763016bb7e06be9252dda4996b2f74c053732c8;hpb=bf48dc69f7f57ad391481c8917570e35f661cc4a;p=picoclvr.git diff --git a/main.py b/main.py index c763016..58e8046 100755 --- a/main.py +++ b/main.py @@ -133,6 +133,11 @@ parser.add_argument("--expr_result_max", type=int, default=99) parser.add_argument("--expr_input_file", type=str, default=None) +############################## +# World options + +parser.add_argument("--world_vqae_nb_epochs", type=int, default=10) + ###################################################################### args = parser.parse_args() @@ -328,6 +333,7 @@ elif args.task == "world": nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.batch_size, + vqae_nb_epochs=args.world_vqae_nb_epochs, device=device, )