# PyTorch
import torch
+import torchvision
from torch import optim
from torch import FloatTensor as Tensor
parser.add_argument('--log_file',
                    type = str, default = 'default.log')
+parser.add_argument('--nb_exemplar_vignettes',
+                    type = int, default = -1,
+                    help = 'Number of vignettes to save as a single image for visual inspection')
+
parser.add_argument('--compress_vignettes',
                    type = distutils.util.strtobool, default = 'True',
                    help = 'Use lossless compression to reduce the memory footprint')
)
self.last_t = t
+def save_exemplar_vignettes(data_set, nb, name):
+    # Draw nb distinct sample indices at random over the whole data set
+    n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)
+
+    for k in range(0, nb):
+        # Locate sample n[k] as a (batch index, index within the batch) pair
+        b = n[k] // data_set.batch_size
+        m = n[k] % data_set.batch_size
+        i, t = data_set.get_batch(b)
+        i = i[m].float()
+        # Rescale the vignette to [0, 1] for visualization
+        i.sub_(i.min())
+        i.div_(i.max())
+        if k == 0: patchwork = Tensor(nb, 1, i.size(1), i.size(2))
+        patchwork[k].copy_(i)
+
+    # Write all the selected vignettes into a single image file
+    torchvision.utils.save_image(patchwork, name)
+
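For reference, save_exemplar_vignettes only assumes that data_set exposes nb_samples, batch_size, and a get_batch(b) method returning an (input, target) pair of batch tensors. Below is a minimal sketch of that interface, using a made-up ToyVignetteSet filled with random tensors rather than the real VignetteSet:

import torch

class ToyVignetteSet:
    # Hypothetical stand-in for VignetteSet, for illustration only
    def __init__(self, nb_samples, batch_size):
        self.nb_samples = nb_samples
        self.batch_size = batch_size
        # Random grayscale 64x64 images in place of real vignettes
        self.inputs = torch.rand(nb_samples, 1, 64, 64)
        self.targets = torch.zeros(nb_samples).long()

    def get_batch(self, b):
        s = b * self.batch_size
        return (self.inputs[s:s + self.batch_size],
                self.targets[s:s + self.batch_size])

# save_exemplar_vignettes(ToyVignetteSet(256, 16), 8, 'toy_exemplars.png')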
######################################################################
if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
train_set.nb_samples / (time.time() - t))
)
+if args.nb_exemplar_vignettes > 0:
+    save_exemplar_vignettes(train_set, args.nb_exemplar_vignettes,
+                            'exemplar_{:d}.png'.format(problem_number))
+
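With a positive value, e.g. --nb_exemplar_vignettes 16 on the command line, this writes one PNG of 16 normalized vignettes per problem into the current directory, named exemplar_<problem_number>.png; the default of -1 leaves the feature disabled.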
if args.validation_error_threshold > 0.0:
    validation_set = VignetteSet(problem_number,
                                 args.nb_validation_samples, args.batch_size,