X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=flatland.git;a=blobdiff_plain;f=test.py;h=0a065fa6887dbdce00d9d82281a6922102d26d4d;hp=314e03dc3d191859f15c4b2fd3e43c0a0456877b;hb=6c81a4b8ff7547e793d3f8bbd41ecdfdb73cbbbe;hpb=5d4e9eaeec9263692d39ca840e498a5f1d818eaa diff --git a/test.py b/test.py index 314e03d..0a065fa 100755 --- a/test.py +++ b/test.py @@ -4,9 +4,11 @@ import torch import torchvision from torchvision import datasets +from _ext import flatland + ###################################################################### -def sequences_to_image(x): +def sequences_to_image(x, gap=1, gap_color = (0, 128, 255)): from PIL import Image nb_sequences = x.size(0) @@ -19,8 +21,6 @@ def sequences_to_image(x): height = x.size(3) width = x.size(4) - gap = 1 - gap_color = (0, 128, 255) result = torch.ByteTensor(nb_channels, gap + nb_sequences * (height + gap), @@ -32,7 +32,9 @@ def sequences_to_image(x): for s in range(0, nb_sequences): for i in range(0, nb_images_per_sequences): - result.narrow(1, gap + s * (height + gap), height).narrow(2, gap + i * (width + gap), width).copy_(x[s][i]) + result.narrow(1, gap + s * (height + gap), height) \ + .narrow(2, gap + i * (width + gap), width) \ + .copy_(x[s][i]) result_numpy = result.cpu().byte().transpose(0, 2).transpose(0, 1).numpy() @@ -40,10 +42,6 @@ def sequences_to_image(x): ###################################################################### -from _ext import flatland - -x = torch.ByteTensor() - -flatland.generate_sequence(10, x) +x = flatland.generate_sequence(10, 6, 80, 80, True, True) -sequences_to_image(x).save('sequences.png') +sequences_to_image(x, gap = 2, gap_color = (0, 0, 0)).save('sequences.png')