Update.
[flatland.git] / test.py
diff --git a/test.py b/test.py
index 314e03d..2853309 100755 (executable)
--- a/test.py
+++ b/test.py
@@ -4,9 +4,11 @@ import torch
 import torchvision
 from torchvision import datasets
 
+from _ext import flatland
+
 ######################################################################
 
-def sequences_to_image(x):
+def sequences_to_image(x, gap=1, gap_color = (0, 128, 255)):
     from PIL import Image
 
     nb_sequences = x.size(0)
@@ -19,8 +21,6 @@ def sequences_to_image(x):
 
     height = x.size(3)
     width = x.size(4)
-    gap = 1
-    gap_color = (0, 128, 255)
 
     result = torch.ByteTensor(nb_channels,
                               gap + nb_sequences * (height + gap),
@@ -32,7 +32,9 @@ def sequences_to_image(x):
 
     for s in range(0, nb_sequences):
         for i in range(0, nb_images_per_sequences):
-            result.narrow(1, gap + s * (height + gap), height).narrow(2, gap + i * (width + gap), width).copy_(x[s][i])
+            result.narrow(1, gap + s * (height + gap), height) \
+                  .narrow(2, gap + i * (width + gap), width) \
+                  .copy_(x[s][i])
 
     result_numpy = result.cpu().byte().transpose(0, 2).transpose(0, 1).numpy()
 
@@ -40,10 +42,6 @@ def sequences_to_image(x):
 
 ######################################################################
 
-from _ext import flatland
-
-x = torch.ByteTensor()
-
-flatland.generate_sequence(10, x)
+x = flatland.generate_sequence(1, 3, 80, 80, True, True)
 
-sequences_to_image(x).save('sequences.png')
+sequences_to_image(x, gap = 2, gap_color = (0, 0, 0)).save('sequences.png')