- ax.imshow(test_A[k], cmap = 'binary', interpolation='nearest')
- delta = 0.
- ax.scatter(test_bx[k, :, 0], torch.full((test_bx.size(1),), delta), color = 'black', marker = 's', clip_on=False)
- ax.scatter(torch.full((test_bx.size(1),), delta), test_bx[k, :, 0], color = 'black', marker = 's', clip_on=False)
- ax.scatter(test_tr[k, :, 0], torch.full((test_tr.size(1),), delta), color = 'black', marker = '^', clip_on=False)
- ax.scatter(torch.full((test_tr.size(1),), delta), test_tr[k, :, 0], color = 'black', marker = '^', clip_on=False)
+ ax.imshow(test_A[k], cmap="binary", interpolation="nearest")
+ delta = 0.0
+ ax.scatter(
+ test_bx[k, :, 0],
+ torch.full((test_bx.size(1),), delta),
+ color="black",
+ marker="s",
+ clip_on=False,
+ )
+ ax.scatter(
+ torch.full((test_bx.size(1),), delta),
+ test_bx[k, :, 0],
+ color="black",
+ marker="s",
+ clip_on=False,
+ )
+ ax.scatter(
+ test_tr[k, :, 0],
+ torch.full((test_tr.size(1),), delta),
+ color="black",
+ marker="^",
+ clip_on=False,
+ )
+ ax.scatter(
+ torch.full((test_tr.size(1),), delta),
+ test_tr[k, :, 0],
+ color="black",
+ marker="^",
+ clip_on=False,
+ )