self.batch_size = batch_size
self.device = device
- self.height = 7
- self.width = 9
+ self.height = 6
+ self.width = 8
- self.train_input = world.generate(
+ self.train_input = world.generate_seq(
nb_train_samples, height=self.height, width=self.width
).to(device)
- self.test_input = world.generate(
+ self.test_input = world.generate_seq(
nb_test_samples, height=self.height, width=self.width
).to(device)