From 952dabe800dba1bb7bb295e3600022ea2fba0b66 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 25 Jun 2024 13:53:31 +0200 Subject: [PATCH] Update. --- tasks.py => quizz_machine.py | 16 ++++++++-------- world.py => sku.py | 0 2 files changed, 8 insertions(+), 8 deletions(-) rename tasks.py => quizz_machine.py (96%) rename world.py => sku.py (100%) diff --git a/tasks.py b/quizz_machine.py similarity index 96% rename from tasks.py rename to quizz_machine.py index 50ded2c..d8ebad8 100755 --- a/tasks.py +++ b/quizz_machine.py @@ -82,12 +82,12 @@ class Task: ###################################################################### -import world +import sky class QuizzMachine(Task): def save_image(self, input, result_dir, filename, logger): - img = world.seq2img(input.to("cpu"), self.height, self.width) + img = sky.seq2img(input.to("cpu"), self.height, self.width) image_name = os.path.join(result_dir, filename) torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4) logger(f"wrote {image_name}") @@ -115,11 +115,11 @@ class QuizzMachine(Task): self.height = 6 self.width = 8 - self.train_w_quizzes = world.generate_seq( + self.train_w_quizzes = sky.generate_seq( nb_train_samples, height=self.height, width=self.width ).to(device) - self.test_w_quizzes = world.generate_seq( + self.test_w_quizzes = sky.generate_seq( nb_test_samples, height=self.height, width=self.width ).to(device) @@ -250,7 +250,7 @@ class QuizzMachine(Task): input = self.train_w_quizzes if for_train else self.test_w_quizzes nb = min(nb, input.size(0)) input[:-nb] = input[nb:].clone() - input[-nb:] = world.generate_seq(nb, height=self.height, width=self.width).to( + input[-nb:] = sky.generate_seq(nb, height=self.height, width=self.width).to( self.device ) @@ -324,9 +324,9 @@ class QuizzMachine(Task): l = self.height * self.width direction = c_quizzes[:, l : l + 1] - direction = world.token_forward * ( - direction == world.token_backward - ) + world.token_backward * (direction == world.token_forward) + direction = sky.token_forward * ( + direction == sky.token_backward + ) + sky.token_backward * (direction == sky.token_forward) reverse_c_quizzes = torch.cat( [c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1 ) diff --git a/world.py b/sku.py similarity index 100% rename from world.py rename to sku.py -- 2.39.5