From 4f0057b363762698f90eea05de154e62b6883bd0 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 25 Jun 2024 19:41:08 +0200 Subject: [PATCH] Update. --- main.py | 1 - problem.py | 2 +- quizz_machine.py | 7 ++----- sky.py | 7 +++---- 4 files changed, 6 insertions(+), 11 deletions(-) diff --git a/main.py b/main.py index 524715a..402e6e5 100755 --- a/main.py +++ b/main.py @@ -396,7 +396,6 @@ def create_c_quizzes( new_c_quizzes[:72], args.result_dir, f"culture_c_quiz_{n_epoch:04d}_{model.id:02d}", - log_string, ) return sum_logits / sum_nb_c_quizzes diff --git a/problem.py b/problem.py index 8d973eb..95a9c41 100755 --- a/problem.py +++ b/problem.py @@ -13,7 +13,7 @@ class Problem: pass # save a file to vizualize quizzes, you can save a txt or png file - def save_quizzes(self, input, result_dir, filename_prefix, logger): + def save_quizzes(self, input, result_dir, filename_prefix): pass # returns a pair (forward_tokens, backward_token) diff --git a/quizz_machine.py b/quizz_machine.py index d63855c..be34847 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -98,7 +98,7 @@ class QuizzMachine: if result_dir is not None: self.problem.save_quizzes( - self.train_w_quizzes[:72], result_dir, f"culture_w_quizzes", logger + self.train_w_quizzes[:72], result_dir, f"culture_w_quizzes" ) def batches(self, split="train", desc=None): @@ -206,10 +206,7 @@ class QuizzMachine: ) self.problem.save_quizzes( - result[:72], - result_dir, - f"culture_prediction_{n_epoch:04d}_{model.id:02d}", - logger, + result[:72], result_dir, f"culture_prediction_{n_epoch:04d}_{model.id:02d}" ) return main_test_accuracy diff --git a/sky.py b/sky.py index ec476a6..1e6ed4d 100755 --- a/sky.py +++ b/sky.py @@ -343,14 +343,13 @@ class Sky(problem.Problem): result.append("".join([self.token2char[v] for v in s])) return result - def save_image(self, input, result_dir, filename, logger): + def save_image(self, input, result_dir, filename): img = self.seq2img(input.to("cpu")) image_name = os.path.join(result_dir, filename) torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4) - logger(f"wrote {image_name}") - def save_quizzes(self, input, result_dir, filename_prefix, logger): - self.save_image(input, result_dir, filename_prefix + ".png", logger) + def save_quizzes(self, input, result_dir, filename_prefix): + self.save_image(input, result_dir, filename_prefix + ".png") ###################################################################### -- 2.20.1