import torch
import torchvision
from torchvision import datasets

from _ext import flatland
######################################################################

def sequences_to_image(x, gap=1, gap_color=(0, 128, 255)):
    """Render a batch of image sequences as one PIL image grid.

    Args:
        x: ByteTensor of shape (nb_sequences, nb_images_per_sequence,
           3, height, width) holding RGB pixel values in [0, 255].
        gap: border width in pixels drawn between frames.
        gap_color: RGB triple used to fill the gaps.

    Returns:
        A PIL.Image in 'RGB' mode with one row per sequence and one
        column per frame.

    Raises:
        ValueError: if x does not have exactly 3 channels.
    """
    from PIL import Image

    nb_sequences = x.size(0)
    nb_images_per_sequence = x.size(1)
    nb_channels = 3

    # Raise instead of print+exit(1): a library function should not
    # terminate the interpreter on bad input.
    if x.size(2) != nb_channels:
        raise ValueError('Can only handle 3 channel tensors.')

    height = x.size(3)
    width = x.size(4)

    # Canvas pre-filled with the gap color; frames are blitted on top,
    # leaving a gap-colored border around every frame.
    result = torch.ByteTensor(nb_channels,
                              gap + nb_sequences * (height + gap),
                              gap + nb_images_per_sequence * (width + gap))

    for c in range(nb_channels):
        result[c].fill_(gap_color[c])

    for s in range(nb_sequences):
        for i in range(nb_images_per_sequence):
            result.narrow(1, gap + s * (height + gap), height) \
                  .narrow(2, gap + i * (width + gap), width) \
                  .copy_(x[s][i])

    # CHW -> HWC for PIL. The transposes produce a non-contiguous view;
    # .contiguous() is required because Image.fromarray reads the raw
    # buffer and expects C-contiguous data. (.byte() was redundant:
    # result is already a ByteTensor.)
    result_numpy = result.cpu().transpose(0, 2).transpose(0, 1) \
                         .contiguous().numpy()

    return Image.fromarray(result_numpy, 'RGB')
######################################################################

# Generate one example sequence and write its rendering out as a PNG.
x = flatland.generate_sequence(1, 3, 80, 80, True, True)

image = sequences_to_image(x, gap=2, gap_color=(0, 0, 0))
image.save('sequences.png')