+ def save_image(self, input, result_dir, filename, logger):
+ img = self.seq2img(input.to("cpu"))
+ image_name = os.path.join(result_dir, filename)
+ torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
+ logger(f"wrote {image_name}")
+
+ def save_quizzes(self, input, result_dir, filename_prefix, logger):
+ self.save_image(input, result_dir, filename_prefix + ".png", logger)
+