+class vignette_logger():
+ def __init__(self, delay_min = 60):
+ self.start_t = time.time()
+ self.last_t = self.start_t
+ self.delay_min = delay_min
+
+ def __call__(self, n, m):
+ t = time.time()
+ if t > self.last_t + self.delay_min:
+ dt = (t - self.start_t) / m
+ log_string('sample_generation {:d} / {:d}'.format(
+ m,
+ n), ' [ETA ' + time.ctime(time.time() + dt * (n - m)) + ']'
+ )
+ self.last_t = t
+
+def save_examplar_vignettes(data_set, nb, name):
+ n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb)
+
+ for k in range(0, nb):
+ b = n[k] // data_set.batch_size
+ m = n[k] % data_set.batch_size
+ i, t = data_set.get_batch(b)
+ i = i[m].float()
+ i.sub_(i.min())
+ i.div_(i.max())
+ if k == 0: patchwork = Tensor(nb, 1, i.size(1), i.size(2))
+ patchwork[k].copy_(i)
+
+ torchvision.utils.save_image(patchwork, name)
+