-if args.task == "sandbox":
- if args.sandbox_level == 0:
- problem = problems.ProblemLevel0(
- nb_sentences=args.sandbox_levels_nb_items,
- len_prompt=args.sandbox_levels_len_source,
- len_result=args.sandbox_levels_len_result,
- )
- elif args.sandbox_level == 1:
- problem = problems.ProblemLevel1(
- nb_operators=args.sandbox_levels_nb_items,
- len_source=args.sandbox_levels_len_source,
- len_result=args.sandbox_levels_len_result,
- )
- elif args.sandbox_level == 2:
- problem = problems.ProblemLevel2(
- len_source=args.sandbox_levels_len_source,
- len_result=args.sandbox_levels_len_result,
- )
- else:
- raise ValueError(f"Unknown sandbox level {args.sandbox_level}")
-
- task = tasks.SandBox(
- # problem,
- # problems.ProblemAddition(zero_padded=False, inverted_result=False),
- # problems.ProblemLenId(len_max=args.sandbox_levels_len_source),
- problems.ProblemTwoTargets(len_total=12, len_targets=4),
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
- logger=log_string,
- device=device,
- )