######################################################################
-import world
+import sky
class QuizzMachine(Task):
def save_image(self, input, result_dir, filename, logger):
- img = world.seq2img(input.to("cpu"), self.height, self.width)
+ img = sky.seq2img(input.to("cpu"), self.height, self.width)
image_name = os.path.join(result_dir, filename)
torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
logger(f"wrote {image_name}")
self.height = 6
self.width = 8
- self.train_w_quizzes = world.generate_seq(
+ self.train_w_quizzes = sky.generate_seq(
nb_train_samples, height=self.height, width=self.width
).to(device)
- self.test_w_quizzes = world.generate_seq(
+ self.test_w_quizzes = sky.generate_seq(
nb_test_samples, height=self.height, width=self.width
).to(device)
input = self.train_w_quizzes if for_train else self.test_w_quizzes
nb = min(nb, input.size(0))
input[:-nb] = input[nb:].clone()
- input[-nb:] = world.generate_seq(nb, height=self.height, width=self.width).to(
+ input[-nb:] = sky.generate_seq(nb, height=self.height, width=self.width).to(
self.device
)
l = self.height * self.width
direction = c_quizzes[:, l : l + 1]
- direction = world.token_forward * (
- direction == world.token_backward
- ) + world.token_backward * (direction == world.token_forward)
+ direction = sky.token_forward * (
+ direction == sky.token_backward
+ ) + sky.token_backward * (direction == sky.token_forward)
reverse_c_quizzes = torch.cat(
[c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1
)