type = str, default = 'train.log')
parser.add_argument('--download',
- type = bool, default = False)
+ action='store_true', default = False)
parser.add_argument('--seed',
type = int, default = 0)
type = float, default = 0.1)
parser.add_argument('--synthesis_sampling',
- type = bool, default = True)
+ action='store_true', default = True)
parser.add_argument('--checkpoint_name',
type = str, default = 'checkpoint.pth')
+parser.add_argument('--picoclvr_many_colors',
+ action='store_true', default = False)
+
######################################################################
args = parser.parse_args()
elif args.data == 'mnist':
task = TaskMNIST(batch_size = args.batch_size, device = device)
elif args.data == 'picoclvr':
- task = TaskPicoCLVR(batch_size = args.batch_size, device = device)
+ task = TaskPicoCLVR(batch_size = args.batch_size, many_colors = args.picoclvr_many_colors, device = device)
else:
raise ValueError(f'Unknown dataset {args.data}.')