-if args.task == "file":
- assert (
- args.filetask_train_file is not None and args.filetask_test_file is not None
- ), "You have to specify the task train and test files"
- task = tasks.TaskFromFile(
- args.filetask_train_file,
- args.filetask_test_file,
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.physical_batch_size,
- shuffle=True,
- device=device,
- )
- args.max_percents_of_test_in_train = 0
-
-elif args.task == "byheart":
- task = tasks.SandBox(
- problem=problems.ProblemByHeart(separation=args.byheart_separation),
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.physical_batch_size,
- logger=log_string,
- device=device,
- )
- args.max_percents_of_test_in_train = -1
-
-elif args.task == "world":
- task = tasks.World(
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.physical_batch_size,
- logger=log_string,
- device=device,
- )
- args.max_percents_of_test_in_train = -1
-
-elif args.task == "learnop":
- task = tasks.SandBox(
- problem=problems.ProblemLearnOperator(),
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.physical_batch_size,
- logger=log_string,
- device=device,
- )
-
-
-elif args.task == "guessop":
- task = tasks.SandBox(
- problem=problems.ProblemGuessOperator(),
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.physical_batch_size,
- logger=log_string,
- device=device,
- )
-
-
-elif args.task == "twotargets":
- task = tasks.SandBox(
- problem=problems.ProblemTwoTargets(),
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.physical_batch_size,
- logger=log_string,
- device=device,
- )
-
-elif args.task == "memory":
- task = tasks.SandBox(
- problem=problems.ProblemMemory(),
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.physical_batch_size,
- logger=log_string,
- device=device,
- )
-
-elif args.task == "mixing":
- task = tasks.SandBox(
- problem=problems.ProblemMixing(
- hard=args.mixing_hard, random_start=not args.mixing_deterministic_start
- ),
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.physical_batch_size,
- logger=log_string,
- device=device,
- )
-
-elif args.task == "addition":
- task = tasks.SandBox(
- problem=problems.ProblemAddition(),
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.physical_batch_size,
- logger=log_string,
- device=device,
- )
-
-elif args.task == "picoclvr":
- task = tasks.PicoCLVR(
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.physical_batch_size,
- height=args.picoclvr_height,
- width=args.picoclvr_width,
- nb_colors=args.picoclvr_nb_colors,
- logger=log_string,
- device=device,
- pruner_train=picoclvr_pruner_train,
- pruner_eval=picoclvr_pruner_eval,
- )
-
-elif args.task == "mnist":
- task = tasks.MNIST(
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.physical_batch_size,
- device=device,