X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=704dff5b95a918637e8cfc0282796322cb706fea;hb=6681907dcc86bf6e159925814d419f522e0e3300;hp=00e19ac78f1695265286ce4e1160423468f753bd;hpb=6f61f9438799d65c980726e28546f8775bf83a60;p=picoclvr.git diff --git a/main.py b/main.py index 00e19ac..704dff5 100755 --- a/main.py +++ b/main.py @@ -99,6 +99,11 @@ parser.add_argument("--rpl_nb_runs", type=int, default=5) parser.add_argument("--rpl_no_prog", action="store_true", default=False) +############################## +# grid options + +parser.add_argument("--grid_size", type=int, default=6) + ############################## # picoclvr options @@ -517,8 +522,7 @@ elif args.task == "grid": nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.batch_size, - height=args.picoclvr_height, - width=args.picoclvr_width, + size=args.grid_size, logger=log_string, device=device, )