+
+u = torch.linspace(0, 1, 101)
+v = u.view(-1, 1).expand(-1, 25).reshape(-1, 1)
+v = model(v).reshape(101, -1)
+mean = v.mean(1)
+std = v.std(1)
+
+ax.fill_between(
+ u.numpy(),
+ (mean - std).detach().numpy(),
+ (mean + std).detach().numpy(),
+ color="#e0e0e0",
+)
+ax.plot(u.numpy(), mean.detach().numpy(), color="red")