nb_test_samples,
batch_size,
size,
+ nb_shapes,
+ nb_colors,
logger=None,
device=torch.device("cpu"),
):
self.device = device
self.batch_size = batch_size
- self.grid_factory = grid.GridFactory(size=size)
+ self.grid_factory = grid.GridFactory(
+ size=size, nb_shapes=nb_shapes, nb_colors=nb_colors
+ )
if logger is not None:
logger(