X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=727b196b3a2f008854fb314389254589ab29d715;hb=a1ae050705970007f965d2586c53e9bd262e46aa;hp=4777a11676447c4137683ec988bd011a0ad69d81;hpb=e56873a0cb64555cbd47e44cdca0ce991765a5fc;p=mygptrnn.git diff --git a/tasks.py b/tasks.py index 4777a11..727b196 100755 --- a/tasks.py +++ b/tasks.py @@ -1473,6 +1473,8 @@ class Grid(Task): nb_test_samples, batch_size, size, + nb_shapes, + nb_colors, logger=None, device=torch.device("cpu"), ): @@ -1480,7 +1482,9 @@ class Grid(Task): self.device = device self.batch_size = batch_size - self.grid_factory = grid.GridFactory(size=size) + self.grid_factory = grid.GridFactory( + size=size, nb_shapes=nb_shapes, nb_colors=nb_colors + ) if logger is not None: logger(