- def policy_batches(self, split="train", nb_to_use=-1, desc=None):
- assert split in {"train", "test"}
- input = self.train_input if split == "train" else self.test_input
- policies = self.train_policies if split == "train" else self.test_policies
- input = input[:, : self.height * self.width]
- policies = policies * (input != maze.v_wall)[:, None]
-
- if nb_to_use > 0:
- input = input[:nb_to_use]
- policies = policies[:nb_to_use]
-
- if desc is None:
- desc = f"epoch-{split}"
- for batch in tqdm.tqdm(
- zip(input.split(self.batch_size), policies.split(self.batch_size)),
- dynamic_ncols=True,
- desc=desc,
- ):
- yield batch
-