bf51360fbcf7a014071d3c0c64e746734767601d
[flatland.git] / test.py
1 #!/usr/bin/env python-for-pytorch
2
3 import torch
4 import torchvision
5 from torchvision import datasets
6
7 ######################################################################
8
def sequences_to_image(x):
    """Render a batch of image sequences as one RGB montage image.

    Sequences are stacked vertically, the frames of each sequence laid out
    horizontally, with a 1-pixel gap of color (0, 128, 255) between and
    around all frames.

    Args:
        x: tensor of shape (nb_sequences, nb_images_per_sequence, 3,
           height, width); values are copied into a byte canvas, so they
           are expected to be in 0..255 — TODO confirm against the caller.

    Returns:
        A ``PIL.Image.Image`` in RGB mode.

    Raises:
        ValueError: if ``x.size(2)`` is not 3 (only RGB tensors are handled).
    """
    from PIL import Image

    nb_sequences = x.size(0)
    nb_images_per_sequences = x.size(1)
    nb_channels = 3

    if x.size(2) != nb_channels:
        # A library helper must not kill the whole process (the original
        # printed and called exit(1)); raise so callers can handle it.
        raise ValueError('Can only handle 3 channel tensors.')

    height = x.size(3)
    width = x.size(4)
    gap = 1
    gap_color = (0, 128, 255)

    # Canvas sized for all frames plus a gap before each row/column and
    # one trailing gap; filled with the gap color channel by channel.
    result = torch.ByteTensor(nb_channels,
                              gap + nb_sequences * (height + gap),
                              gap + nb_images_per_sequences * (width + gap))

    result[0].fill_(gap_color[0])
    result[1].fill_(gap_color[1])
    result[2].fill_(gap_color[2])

    # Copy each frame into its slot; narrow() gives a view, so copy_ writes
    # directly into the canvas.
    for s in range(nb_sequences):
        for i in range(nb_images_per_sequences):
            result.narrow(1, gap + s * (height + gap), height) \
                  .narrow(2, gap + i * (width + gap), width) \
                  .copy_(x[s][i])

    # CHW -> HWC for PIL; .contiguous() guarantees a C-ordered array —
    # the permuted view alone is strided, which Image.fromarray may not
    # accept directly.
    result_numpy = result.cpu().byte().permute(1, 2, 0).contiguous().numpy()

    return Image.fromarray(result_numpy, 'RGB')
40
41 ######################################################################
42
from _ext import mylib


def _main():
    """Generate 10 sequences with the C extension and save them as a PNG."""
    # generate_sequence fills the (initially empty) byte tensor in place;
    # presumably with shape (nb, frames, 3, h, w) — verify against _ext.
    x = torch.ByteTensor()
    mylib.generate_sequence(10, x)
    sequences_to_image(x).save('sequences.png')


if __name__ == '__main__':
    _main()