- img = world.sample2img(result.to("cpu"), self.height, self.width)
- image_name = os.path.join(result_dir, f"world_result_{n_epoch:04d}.png")
- torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2)
- logger(f"wrote {image_name}")
+ self.save_image(result, result_dir, f"world_result_{n_epoch:04d}.png", logger)
+
+ def store_new_problems(self, new_problems):
+ nb_current = self.train_input.size(0)
+ nb_new = new_problems.size(0)
+ if nb_new >= nb_current:
+ self.train_input[...] = new_problems[:nb_current]
+ else:
+ nb_kept = nb_current - nb_new
+ self.train_input[:nb_kept] = self.train_input[-nb_kept:].clone()
+ self.train_input[nb_kept:] = new_problems