X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=flatland.git;a=blobdiff_plain;f=test.py;fp=test.py;h=2853309befc2e6850c689af3a455df001b97cecc;hp=c6b6c48c785a3b2a28a333c77f64af242906db79;hb=1f91ec6f67da83525115f49dcc7d535ff2e71ef0;hpb=e0f96aaef35ffaf34b912c7fa1473cb67b7a3dae diff --git a/test.py b/test.py index c6b6c48..2853309 100755 --- a/test.py +++ b/test.py @@ -8,7 +8,7 @@ from _ext import flatland ###################################################################### -def sequences_to_image(x): +def sequences_to_image(x, gap=1, gap_color = (0, 128, 255)): from PIL import Image nb_sequences = x.size(0) @@ -21,8 +21,6 @@ def sequences_to_image(x): height = x.size(3) width = x.size(4) - gap = 1 - gap_color = (0, 128, 255) result = torch.ByteTensor(nb_channels, gap + nb_sequences * (height + gap), @@ -34,7 +32,9 @@ def sequences_to_image(x): for s in range(0, nb_sequences): for i in range(0, nb_images_per_sequences): - result.narrow(1, gap + s * (height + gap), height).narrow(2, gap + i * (width + gap), width).copy_(x[s][i]) + result.narrow(1, gap + s * (height + gap), height) \ + .narrow(2, gap + i * (width + gap), width) \ + .copy_(x[s][i]) result_numpy = result.cpu().byte().transpose(0, 2).transpose(0, 1).numpy() @@ -42,6 +42,6 @@ def sequences_to_image(x): ###################################################################### -x = flatland.generate_sequence(5, 3, 128, 96) +x = flatland.generate_sequence(1, 3, 80, 80, True, True) -sequences_to_image(x).save('sequences.png') +sequences_to_image(x, gap = 2, gap_color = (0, 0, 0)).save('sequences.png')