+ else:
+ raise ValueError(f"Unknown sandbox level {args.sandbox_level}")
+
+ task = tasks.SandBox(
+ # problem,
+ # problems.ProblemAddition(zero_padded=False, inverted_result=False),
+ problems.ProblemLenId(len_max=args.sandbox_levels_len_source),
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ logger=log_string,
+ device=device,
+ )
+
+elif args.task == "picoclvr":
+ task = tasks.PicoCLVR(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ height=args.picoclvr_height,
+ width=args.picoclvr_width,
+ nb_colors=args.picoclvr_nb_colors,
+ logger=log_string,
+ device=device,
+ pruner_train=picoclvr_pruner_train,
+ pruner_eval=picoclvr_pruner_eval,
+ )
+
+elif args.task == "mnist":
+ task = tasks.MNIST(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ device=device,
+ )
+
+elif args.task == "maze":
+ task = tasks.Maze(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ height=args.maze_height,
+ width=args.maze_width,
+ nb_walls=args.maze_nb_walls,
+ device=device,
+ )
+
+elif args.task == "snake":
+ task = tasks.Snake(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ height=args.snake_height,
+ width=args.snake_width,
+ nb_colors=args.snake_nb_colors,
+ length=args.snake_length,
+ prompt_length=args.snake_length // 2,
+ device=device,
+ )
+
+elif args.task == "stack":
+ task = tasks.Stack(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ logger=log_string,
+ nb_steps=args.stack_nb_steps,
+ nb_stacks=args.stack_nb_stacks,
+ nb_digits=args.stack_nb_digits,
+ fraction_values_for_train=args.stack_fraction_values_for_train,
+ device=device,
+ )
+
+elif args.task == "expr":
+ task = tasks.Expr(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ nb_variables=args.expr_nb_variables,
+ sequence_length=args.expr_sequence_length,
+ operand_max=args.expr_operand_max,
+ result_max=args.expr_result_max,
+ batch_size=args.batch_size,
+ device=device,
+ )
+
+elif args.task == "rpl":
+ task = tasks.RPL(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ nb_starting_values=args.rpl_nb_starting_values,
+ max_input=args.rpl_max_input,
+ prog_len=args.rpl_prog_len,
+ nb_runs=args.rpl_nb_runs,
+ no_prog=args.rpl_no_prog,
+ logger=log_string,
+ device=device,
+ )
+
+elif args.task == "world":
+ task = tasks.World(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ vqae_nb_epochs=args.world_vqae_nb_epochs,
+ logger=log_string,
+ device=device,
+ )