Update.
[flatland.git] / test.py
diff --git a/test.py b/test.py
index d6d2df1..c6b6c48 100755 (executable)
--- a/test.py
+++ b/test.py
@@ -4,16 +4,44 @@ import torch
 import torchvision
 from torchvision import datasets
 
-from _ext import mylib
+from _ext import flatland
 
-x = torch.ByteTensor(4, 5).fill_(0)
+######################################################################
 
-print(x.size())
+def sequences_to_image(x):
+    from PIL import Image
 
-mylib.generate_sequence(x)
+    nb_sequences = x.size(0)
+    nb_images_per_sequences = x.size(1)
+    nb_channels = 3
 
-print(x.size())
+    if x.size(2) != nb_channels:
+        print('Can only handle 3 channel tensors.')
+        exit(1)
 
-x = x.float().sub_(128).div_(128)
+    height = x.size(3)
+    width = x.size(4)
+    gap = 1
+    gap_color = (0, 128, 255)
 
-torchvision.utils.save_image(x[0], 'example.png')
+    result = torch.ByteTensor(nb_channels,
+                              gap + nb_sequences * (height + gap),
+                              gap + nb_images_per_sequences * (width + gap))
+
+    result[0].fill_(gap_color[0])
+    result[1].fill_(gap_color[1])
+    result[2].fill_(gap_color[2])
+
+    for s in range(0, nb_sequences):
+        for i in range(0, nb_images_per_sequences):
+            result.narrow(1, gap + s * (height + gap), height).narrow(2, gap + i * (width + gap), width).copy_(x[s][i])
+
+    result_numpy = result.cpu().byte().transpose(0, 2).transpose(0, 1).numpy()
+
+    return Image.fromarray(result_numpy, 'RGB')
+
######################################################################

# Generate a batch of sequences with the native extension and write the
# tiled visualization to disk.
sequence_batch = flatland.generate_sequence(5, 3, 128, 96)
image = sequences_to_image(sequence_batch)
image.save('sequences.png')