X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=f4e4f5c2c11e27d9be2666b81fc82d6e02f09b40;hb=8d9cd6a2c09da2105ca17b04df94fcf84e8de954;hp=496a6034b857d303baa1a32ab8bb48fd68bf84eb;hpb=bd0903943d9b44e0de58835c63c7a59d72a65c94;p=picoclvr.git diff --git a/main.py b/main.py index 496a603..f4e4f5c 100755 --- a/main.py +++ b/main.py @@ -164,6 +164,8 @@ parser.add_argument("--expr_input_file", type=str, default=None) parser.add_argument("--mixing_hard", action="store_true", default=False) +parser.add_argument("--mixing_deterministic_start", action="store_true", default=False) + ###################################################################### args = parser.parse_args() @@ -416,7 +418,9 @@ elif args.task == "twotargets": elif args.task == "mixing": task = tasks.SandBox( - problem=problems.ProblemMixing(hard=args.mixing_hard), + problem=problems.ProblemMixing( + hard=args.mixing_hard, random_start=not args.mixing_deterministic_start + ), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.batch_size,