X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=55379652a8766a72fa8b20035d4176702f791251;hb=525bd24014786b53638dea78cfb88035a2b99d97;hp=b88847ef75e58bd8aca48df3a706319c768f1087;hpb=504f61114d90b57e1d0faf55a298756da2c8fbfa;p=culture.git diff --git a/main.py b/main.py index b88847e..5537965 100755 --- a/main.py +++ b/main.py @@ -221,10 +221,10 @@ else: assert args.nb_train_samples % args.batch_size == 0 assert args.nb_test_samples % args.batch_size == 0 -if args.problem=="sky": - problem=sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2, speed=2), -elif args.problem="wireworld": - problem=wireworld.Wireworld(height=10, width=15, nb_iterations=4) +if args.problem == "sky": + problem = (sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2, speed=2),) +elif args.problem == "wireworld": + problem = wireworld.Wireworld(height=10, width=15, nb_iterations=4) else: raise ValueError