-class Task:
- def batches(self, split="train", nb_to_use=-1, desc=None):
- pass
-
- def vocabulary_size(self):
- pass
-
- def produce_results(
- self, n_epoch, model, result_dir, logger, deterministic_synthesis
- ):
- pass
-
-
-######################################################################
-
-import sky
-
-
-class QuizzMachine(Task):
- def save_image(self, input, result_dir, filename, logger):
- img = sky.seq2img(input.to("cpu"), self.height, self.width)
- image_name = os.path.join(result_dir, filename)
- torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
- logger(f"wrote {image_name}")
-
- def save_quizzes(self, input, result_dir, filename_prefix, logger):
- self.save_image(input, result_dir, filename_prefix + ".png", logger)
-