X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=flatland.git;a=blobdiff_plain;f=test.py;h=c6b6c48c785a3b2a28a333c77f64af242906db79;hp=de408aa994dc3224e1949900e89567bf9af48c66;hb=4aed0ce274b7c0e379651c28e439375c821c047a;hpb=2cd32038873961c8ff3861efb218fad75fbcbf69 diff --git a/test.py b/test.py index de408aa..c6b6c48 100755 --- a/test.py +++ b/test.py @@ -4,17 +4,44 @@ import torch import torchvision from torchvision import datasets -from _ext import mylib +from _ext import flatland -x = torch.ByteTensor(4, 5).fill_(0) +###################################################################### -print(x.size()) +def sequences_to_image(x): + from PIL import Image -mylib.generate_sequence(8, x) + nb_sequences = x.size(0) + nb_images_per_sequences = x.size(1) + nb_channels = 3 -print(x.size()) + if x.size(2) != nb_channels: + print('Can only handle 3 channel tensors.') + exit(1) -x = x.float().sub_(128).div_(128) + height = x.size(3) + width = x.size(4) + gap = 1 + gap_color = (0, 128, 255) -for s in range(0, x.size(0)): - torchvision.utils.save_image(x[s], 'example_' + str(s) + '.png') + result = torch.ByteTensor(nb_channels, + gap + nb_sequences * (height + gap), + gap + nb_images_per_sequences * (width + gap)) + + result[0].fill_(gap_color[0]) + result[1].fill_(gap_color[1]) + result[2].fill_(gap_color[2]) + + for s in range(0, nb_sequences): + for i in range(0, nb_images_per_sequences): + result.narrow(1, gap + s * (height + gap), height).narrow(2, gap + i * (width + gap), width).copy_(x[s][i]) + + result_numpy = result.cpu().byte().transpose(0, 2).transpose(0, 1).numpy() + + return Image.fromarray(result_numpy, 'RGB') + +###################################################################### + +x = flatland.generate_sequence(5, 3, 128, 96) + +sequences_to_image(x).save('sequences.png')