seq_length = 100
def positions_to_sequences(tr = None, bx = None, noise_level = 0.3):
- st = torch.arange(seq_length).float()
+ st = torch.arange(seq_length, device = device).float()
st = st[None, :, None]
tr = tr[:, None, :, :]
bx = bx[:, None, :, :]
x = torch.cat((xtr, xbx), 2)
- # u = x.sign()
u = F.max_pool1d(x.sign().permute(0, 2, 1), kernel_size = 2, stride = 1).permute(0, 2, 1)
collisions = (u.sum(2) > 1).max(1).values
# Position / height / width
- tr = torch.empty(nb, 2, 3)
+ tr = torch.empty(nb, 2, 3, device = device)
tr[:, :, 0].uniform_(seq_width_max/2, seq_length - seq_width_max/2)
tr[:, :, 1].uniform_(seq_height_min, seq_height_max)
tr[:, :, 2].uniform_(seq_width_min, seq_width_max)
- bx = torch.empty(nb, 2, 3)
+ bx = torch.empty(nb, 2, 3, device = device)
bx[:, :, 0].uniform_(seq_width_max/2, seq_length - seq_width_max/2)
bx[:, :, 1].uniform_(seq_height_min, seq_height_max)
bx[:, :, 2].uniform_(seq_width_min, seq_width_max)
delta = -1.
if tr is not None:
- ax.scatter(test_tr[k, :, 0], torch.full((test_tr.size(1),), delta), color = 'black', marker = '^', clip_on=False)
+ ax.scatter(tr[:, 0].cpu(), torch.full((tr.size(0),), delta), color = 'black', marker = '^', clip_on=False)
if bx is not None:
- ax.scatter(test_bx[k, :, 0], torch.full((test_bx.size(1),), delta), color = 'black', marker = 's', clip_on=False)
+ ax.scatter(bx[:, 0].cpu(), torch.full((bx.size(0),), delta), color = 'black', marker = 's', clip_on=False)
fig.savefig(filename, bbox_inches='tight')