parser.add_argument("--dirty_debug", action="store_true", default=False)
+parser.add_argument("--sky_height", type=int, default=6)
+
+parser.add_argument("--sky_width", type=int, default=8)
+
+parser.add_argument("--sky_nb_birds", type=int, default=3)
+
+parser.add_argument("--sky_nb_iterations", type=int, default=2)
+
+parser.add_argument("--sky_speed", type=int, default=3)
+
######################################################################
args = parser.parse_args()
assert args.nb_test_samples % args.batch_size == 0
if args.problem == "sky":
- problem = sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2, speed=3)
+ problem = sky.Sky(
+ height=args.sky_height,
+ width=args.sky_width,
+ nb_birds=args.sky_nb_birds,
+ nb_iterations=args.sky_nb_iterations,
+ speed=args.sky_speed,
+ )
elif args.problem == "wireworld":
problem = wireworld.Wireworld(height=8, width=10, nb_iterations=2, speed=5)
else: