- # problem,
- # problems.ProblemAddition(zero_padded=False, inverted_result=False),
- # problems.ProblemLenId(len_max=args.sandbox_levels_len_source),
- problems.ProblemTwoTargets(len_total=16, len_targets=4),
+ problem=problems.ProblemTwoTargets(),
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ logger=log_string,
+ device=device,
+ )
+
+elif args.task == "memory":
+ task = tasks.SandBox(
+ problem=problems.ProblemMemory(),
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ logger=log_string,
+ device=device,
+ )
+
+elif args.task == "mixing":
+ task = tasks.SandBox(
+ problem=problems.ProblemMixing(
+ hard=args.mixing_hard, random_start=not args.mixing_deterministic_start
+ ),
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ logger=log_string,
+ device=device,
+ )
+
+elif args.task == "addition":
+ task = tasks.SandBox(
+ problem=problems.ProblemAddition(),