- def print_worlds(
- self, src=None, comments=[], width=None, printer=print, ansi_term=False
- ):
- if src is None:
- src = list(self.worlds)
-
- height = max([x.size(0) if torch.is_tensor(x) else 1 for x in src])
-
- def tile(n):
- n = n.item()
- if n in self.id2tile:
- return self.id2tile[n]
- else:
- return "?"
-
- for k in range(height):
-
- def f(x):
- if torch.is_tensor(x):
- if x.dim() == 0:
- x = str(x.item())
- return " " * len(x) if k < height - 1 else x
- else:
- s = "".join([tile(n) for n in x[k]])
- if ansi_term:
- for u, c in [("#", 40), ("$", 31), ("@", 32)] + [
- (x, 36) for x in "aAbBcC"
- ]:
- s = s.replace(u, f"\u001b[{c}m{u}\u001b[0m")
- return s
- else:
- return " " * len(x) if k < height - 1 else x
-
- printer("|".join([f(x) for x in src]))
-