X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=305bd3cfb75a351d1b8513dcf9ce136c7a55844f;hb=2192d72289bbf2cd069f67d3e93daf7934f886af;hp=58e80462609bcda7cb8bc43f88d5dfec9022b8d4;hpb=8e23dd068df00df61c690ffa89ecc8cb9db4b32d;p=picoclvr.git diff --git a/main.py b/main.py index 58e8046..305bd3c 100755 --- a/main.py +++ b/main.py @@ -136,7 +136,7 @@ parser.add_argument("--expr_input_file", type=str, default=None) ############################## # World options -parser.add_argument("--world_vqae_nb_epochs", type=int, default=10) +parser.add_argument("--world_vqae_nb_epochs", type=int, default=25) ###################################################################### @@ -187,9 +187,9 @@ default_args = { "nb_test_samples": 10000, }, "world": { - "nb_epochs": 5, + "nb_epochs": 10, "batch_size": 25, - "nb_train_samples": 10000, + "nb_train_samples": 125000, "nb_test_samples": 1000, }, } @@ -334,6 +334,7 @@ elif args.task == "world": nb_test_samples=args.nb_test_samples, batch_size=args.batch_size, vqae_nb_epochs=args.world_vqae_nb_epochs, + logger=log_string, device=device, )