"--task",
type=str,
default="picoclvr",
- help="picoclvr, mnist, maze, snake, stack, expr",
+ help="picoclvr, mnist, maze, snake, stack, expr, world",
)
parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
parser.add_argument("--snake_length", type=int, default=200)
##############################
-# Snake options
+# Stack options
parser.add_argument("--stack_nb_steps", type=int, default=100)
"nb_train_samples": 1000000,
"nb_test_samples": 10000,
},
+ "world": {
+ "nb_epochs": 5,
+ "batch_size": 25,
+ "nb_train_samples": 10000,
+ "nb_test_samples": 1000,
+ },
}
if args.task in default_args:
device=device,
)
+elif args.task == "world":
+ task = tasks.World(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ device=device,
+ )
+
else:
raise ValueError(f"Unknown task {args.task}")