From c8979c695ad584c54d605b8f183e5d2e99f2d1cc Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 25 Jun 2024 18:11:35 +0200 Subject: [PATCH] Update. --- quizz_machine.py | 29 +++++++++++------------------ sky.py | 25 ++++++++++++++++++++++++- 2 files changed, 35 insertions(+), 19 deletions(-) diff --git a/quizz_machine.py b/quizz_machine.py index a3da365..28b94d1 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -70,15 +70,6 @@ import sky class QuizzMachine: - def save_image(self, input, result_dir, filename, logger): - img = self.sky.seq2img(input.to("cpu")) - image_name = os.path.join(result_dir, filename) - torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4) - logger(f"wrote {image_name}") - - def save_quizzes(self, input, result_dir, filename_prefix, logger): - self.save_image(input, result_dir, filename_prefix + ".png", logger) - def make_ar_mask(self, input): b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2 return b.long()[None, :].expand_as(input) @@ -94,12 +85,12 @@ class QuizzMachine: ): super().__init__() - self.sky = sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2) + self.problem = sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2) self.batch_size = batch_size self.device = device - self.train_w_quizzes = self.sky.generate_seq(nb_train_samples).to(device) - self.test_w_quizzes = self.sky.generate_seq(nb_test_samples).to(device) + self.train_w_quizzes = self.problem.generate_seq(nb_train_samples).to(device) + self.test_w_quizzes = self.problem.generate_seq(nb_test_samples).to(device) self.nb_codes = max(self.train_w_quizzes.max(), self.test_w_quizzes.max()) + 1 @@ -107,7 +98,7 @@ class QuizzMachine: self.test_c_quizzes = [] if result_dir is not None: - self.save_quizzes( + self.problem.save_quizzes( self.train_w_quizzes[:72], result_dir, f"culture_w_quizzes", logger ) @@ -215,7 +206,7 @@ class QuizzMachine: device=self.device, ) - self.save_quizzes( + self.problem.save_quizzes( result[:72], result_dir, f"culture_prediction_{n_epoch:04d}_{model.id:02d}", @@ -228,7 +219,7 @@ class QuizzMachine: input = self.train_w_quizzes if for_train else self.test_w_quizzes nb = min(nb, input.size(0)) input[:-nb] = input[nb:].clone() - input[-nb:] = self.sky.generate_seq(nb).to(self.device) + input[-nb:] = self.problem.generate_seq(nb).to(self.device) def store_c_quizzes(self, new_c_quizzes, for_train=True): if for_train: @@ -298,11 +289,13 @@ class QuizzMachine: ############################################################### # Create the reverse quizzes + token_forward, token_backward = self.problem.direction_tokens() + l = (c_quizzes.size(1) - 1) // 2 direction = c_quizzes[:, l : l + 1] - direction = self.sky.token_forward * ( - direction == self.sky.token_backward - ) + self.sky.token_backward * (direction == self.sky.token_forward) + direction = self.problem.token_forward * ( + direction == self.problem.token_backward + ) + self.problem.token_backward * (direction == self.problem.token_forward) reverse_c_quizzes = torch.cat( [c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1 ) diff --git a/sky.py b/sky.py index a90e37d..cb25ea0 100755 --- a/sky.py +++ b/sky.py @@ -5,7 +5,7 @@ # Written by Francois Fleuret -import math, sys, tqdm +import math, sys, tqdm, os import torch, torchvision @@ -15,6 +15,17 @@ from torch.nn import functional as F ###################################################################### +class Problem: + def generate_seq(self, nb_train_samples): + pass + + def save_quizzes(self, input, result_dir, filename_prefix, logger): + pass + + def direction_tokens(self): + pass + + class Sky: colors = torch.tensor( [ @@ -48,6 +59,9 @@ class Sky: self.nb_birds = nb_birds self.nb_iterations = nb_iterations + def direction_tokens(self): + return self.token_forward, self.token_backward + def generate_seq(self, nb, return_iterations=False): pairs = [] kept_iterations = [] @@ -338,6 +352,15 @@ class Sky: result.append("".join([self.token2char[v] for v in s])) return result + def save_image(self, input, result_dir, filename, logger): + img = self.seq2img(input.to("cpu")) + image_name = os.path.join(result_dir, filename) + torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4) + logger(f"wrote {image_name}") + + def save_quizzes(self, input, result_dir, filename_prefix, logger): + self.save_image(input, result_dir, filename_prefix + ".png", logger) + ###################################################################### -- 2.39.5