X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=quizz_machine.py;h=a3da36525c01fc49185f4d080f929e5162bd59f0;hb=b76e3f632315c63dbd8f11a53b187f23057e4e1f;hp=43fd868f229e8edb5aeb5d86742cdf564ce7f57b;hpb=6dbc18a5db82b12b06212841426896412e8bd6de;p=culture.git diff --git a/quizz_machine.py b/quizz_machine.py index 43fd868..a3da365 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -71,7 +71,7 @@ import sky class QuizzMachine: def save_image(self, input, result_dir, filename, logger): - img = sky.seq2img(input.to("cpu"), self.height, self.width) + img = self.sky.seq2img(input.to("cpu")) image_name = os.path.join(result_dir, filename) torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4) logger(f"wrote {image_name}") @@ -94,18 +94,12 @@ class QuizzMachine: ): super().__init__() + self.sky = sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2) self.batch_size = batch_size self.device = device - self.height = 6 - self.width = 8 - self.train_w_quizzes = sky.generate_seq( - nb_train_samples, height=self.height, width=self.width - ).to(device) - - self.test_w_quizzes = sky.generate_seq( - nb_test_samples, height=self.height, width=self.width - ).to(device) + self.train_w_quizzes = self.sky.generate_seq(nb_train_samples).to(device) + self.test_w_quizzes = self.sky.generate_seq(nb_test_samples).to(device) self.nb_codes = max(self.train_w_quizzes.max(), self.test_w_quizzes.max()) + 1 @@ -234,9 +228,7 @@ class QuizzMachine: input = self.train_w_quizzes if for_train else self.test_w_quizzes nb = min(nb, input.size(0)) input[:-nb] = input[nb:].clone() - input[-nb:] = sky.generate_seq(nb, height=self.height, width=self.width).to( - self.device - ) + input[-nb:] = self.sky.generate_seq(nb).to(self.device) def store_c_quizzes(self, new_c_quizzes, for_train=True): if for_train: @@ -258,7 +250,7 @@ class QuizzMachine: # Generate quizzes with model c_quizzes = torch.empty( - nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64 + nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64 ) ar_mask = torch.full(c_quizzes.size(), 1, device=self.device) @@ -306,11 +298,11 @@ class QuizzMachine: ############################################################### # Create the reverse quizzes - l = self.height * self.width + l = (c_quizzes.size(1) - 1) // 2 direction = c_quizzes[:, l : l + 1] - direction = sky.token_forward * ( - direction == sky.token_backward - ) + sky.token_backward * (direction == sky.token_forward) + direction = self.sky.token_forward * ( + direction == self.sky.token_backward + ) + self.sky.token_backward * (direction == self.sky.token_forward) reverse_c_quizzes = torch.cat( [c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1 )