+picoclvr_pruner_eval = (
+ (lambda p: not picoclvr_pruner_horizontal_green(p))
+ if args.picocvlr_prune_properties in {"train+eval", "eval"}
+ else None
+)
+
+######################################################################
+
+if args.physical_batch_size is None:
+ args.physical_batch_size = args.batch_size
+else:
+ assert args.batch_size % args.physical_batch_size == 0
+
+assert args.nb_train_samples % args.batch_size == 0
+assert args.nb_test_samples % args.batch_size == 0
+
+if args.task == "file":
+ assert (
+ args.filetask_train_file is not None and args.filetask_test_file is not None
+ ), "You have to specify the task train and test files"
+ task = tasks.TaskFromFile(
+ args.filetask_train_file,
+ args.filetask_test_file,
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.physical_batch_size,
+ shuffle=True,
+ device=device,
+ )
+ args.max_percents_of_test_in_train = 0
+
+elif args.task == "byheart":
+ task = tasks.SandBox(
+ problem=problems.ProblemByHeart(separation=args.byheart_separation),
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.physical_batch_size,
+ logger=log_string,
+ device=device,
+ )
+ args.max_percents_of_test_in_train = -1
+
+elif args.task == "learnop":
+ task = tasks.SandBox(
+ problem=problems.ProblemLearnOperator(),
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.physical_batch_size,
+ logger=log_string,
+ device=device,
+ )
+
+
+elif args.task == "guessop":
+ task = tasks.SandBox(
+ problem=problems.ProblemGuessOperator(),
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.physical_batch_size,
+ logger=log_string,
+ device=device,
+ )
+
+
+elif args.task == "twotargets":
+ task = tasks.SandBox(
+ problem=problems.ProblemTwoTargets(),
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.physical_batch_size,
+ logger=log_string,
+ device=device,
+ )
+
+elif args.task == "memory":
+ task = tasks.SandBox(
+ problem=problems.ProblemMemory(),
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.physical_batch_size,
+ logger=log_string,
+ device=device,
+ )
+
+elif args.task == "mixing":
+ task = tasks.SandBox(
+ problem=problems.ProblemMixing(
+ hard=args.mixing_hard, random_start=not args.mixing_deterministic_start
+ ),
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.physical_batch_size,
+ logger=log_string,
+ device=device,
+ )
+
+elif args.task == "addition":
+ task = tasks.SandBox(
+ problem=problems.ProblemAddition(),
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.physical_batch_size,
+ logger=log_string,
+ device=device,
+ )
+
+elif args.task == "picoclvr":
+ task = tasks.PicoCLVR(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.physical_batch_size,
+ height=args.picoclvr_height,
+ width=args.picoclvr_width,
+ nb_colors=args.picoclvr_nb_colors,
+ logger=log_string,
+ device=device,
+ pruner_train=picoclvr_pruner_train,
+ pruner_eval=picoclvr_pruner_eval,
+ )
+
+elif args.task == "mnist":
+ task = tasks.MNIST(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.physical_batch_size,
+ device=device,
+ )
+
+elif args.task == "maze":
+ task = tasks.Maze(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.physical_batch_size,
+ height=args.maze_height,
+ width=args.maze_width,
+ nb_walls=args.maze_nb_walls,
+ device="cpu",
+ )
+
+elif args.task == "snake":
+ task = tasks.Snake(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.physical_batch_size,
+ height=args.snake_height,
+ width=args.snake_width,
+ nb_colors=args.snake_nb_colors,
+ length=args.snake_length,
+ prompt_length=args.snake_length // 2,
+ device=device,
+ )
+
+elif args.task == "stack":
+ task = tasks.Stack(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.physical_batch_size,
+ logger=log_string,
+ nb_steps=args.stack_nb_steps,
+ nb_stacks=args.stack_nb_stacks,
+ nb_digits=args.stack_nb_digits,
+ fraction_values_for_train=args.stack_fraction_values_for_train,
+ device=device,
+ )
+
+elif args.task == "expr":
+ task = tasks.Expr(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ nb_variables=args.expr_nb_variables,
+ sequence_length=args.expr_sequence_length,
+ operand_max=args.expr_operand_max,
+ result_max=args.expr_result_max,
+ batch_size=args.physical_batch_size,
+ device=device,
+ )
+
+elif args.task == "rpl":
+ task = tasks.RPL(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.physical_batch_size,
+ nb_starting_values=args.rpl_nb_starting_values,
+ max_input=args.rpl_max_input,
+ prog_len=args.rpl_prog_len,
+ nb_runs=args.rpl_nb_runs,
+ no_prog=args.rpl_no_prog,
+ logger=log_string,
+ device=device,
+ )
+
+elif args.task == "grid":
+ task = tasks.Grid(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.physical_batch_size,
+ size=args.grid_size,
+ fraction_play=args.grid_fraction_play,
+ logger=log_string,
+ device=device,
+ )
+
+elif args.task == "qmlp":
+ task = tasks.QMLP(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.physical_batch_size,
+ result_dir=args.result_dir,
+ logger=log_string,
+ device=device,
+ )
+
+elif args.task == "greed":
+ task = tasks.Greed(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.physical_batch_size,
+ height=args.greed_height,
+ width=args.greed_width,
+ T=args.greed_T,
+ nb_walls=args.greed_nb_walls,
+ nb_coins=args.greed_nb_coins,
+ logger=log_string,
+ device=device,
+ )
+
+else:
+ raise ValueError(f"Unknown task {args.task}")
+
+######################################################################
+
+log_string(f"device {device}")
+