index1 = torch.randint(src.size(2), (src.size(3), src.size(1), src.size(3)))
index2 = torch.randint(src.size(3), (src.size(1),))
index1 = torch.randint(src.size(2), (src.size(3), src.size(1), src.size(3)))
index2 = torch.randint(src.size(3), (src.size(1),))