assert args.nb_train_samples % args.batch_size == 0
assert args.nb_test_samples % args.batch_size == 0
-if args.problem=="sky":
- problem=sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2, speed=2),
-elif args.problem="wireworld":
- problem=wireworld.Wireworld(height=10, width=15, nb_iterations=4)
+if args.problem == "sky":
+ problem = (sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2, speed=2),)
+elif args.problem == "wireworld":
+ problem = wireworld.Wireworld(height=10, width=15, nb_iterations=4)
else:
raise ValueError