Update.
[picoclvr.git] / main.py
diff --git a/main.py b/main.py
index 58e8046..305bd3c 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -136,7 +136,7 @@ parser.add_argument("--expr_input_file", type=str, default=None)
 ##############################
 # World options
 
-parser.add_argument("--world_vqae_nb_epochs", type=int, default=10)
+parser.add_argument("--world_vqae_nb_epochs", type=int, default=25)
 
 ######################################################################
 
@@ -187,9 +187,9 @@ default_args = {
         "nb_test_samples": 10000,
     },
     "world": {
-        "nb_epochs": 5,
+        "nb_epochs": 10,
         "batch_size": 25,
-        "nb_train_samples": 10000,
+        "nb_train_samples": 125000,
         "nb_test_samples": 1000,
     },
 }
@@ -334,6 +334,7 @@ elif args.task == "world":
         nb_test_samples=args.nb_test_samples,
         batch_size=args.batch_size,
         vqae_nb_epochs=args.world_vqae_nb_epochs,
+        logger=log_string,
         device=device,
     )