From 5332c56acd44d7049f3fbb33a8643482e0c71f4d Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 25 Aug 2023 22:33:48 +0200 Subject: [PATCH] Update. --- grid.py | 28 ++++++++++++++++------------ tasks.py | 4 ++-- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/grid.py b/grid.py index f72c8e3..5b28914 100755 --- a/grid.py +++ b/grid.py @@ -19,7 +19,7 @@ name_colors = ["red", "yellow", "blue", "green", "white", "purple"] class GridFactory: def __init__( self, - size=4, + size=6, max_nb_items=4, max_nb_transformations=3, nb_questions=4, @@ -143,14 +143,14 @@ class GridFactory: def generate_scene_and_questions(self): while True: while True: - scene = self.generate_scene() - true = self.all_properties(scene) + start_scene = self.generate_scene() + true = self.all_properties(start_scene) if len(true) >= self.nb_questions: break - start = self.grid_positions(scene) + start = self.grid_positions(start_scene) - scene, transformations = self.random_transformations(scene) + scene, transformations = self.random_transformations(start_scene) # transformations=[] @@ -185,7 +185,7 @@ class GridFactory: + questions ) - return scene, result + return start_scene, scene, result def generate_samples(self, nb, progress_bar=None): result = [] @@ -195,7 +195,7 @@ class GridFactory: r = progress_bar(r) for _ in r: - result.append(self.generate_scene_and_questions()[1]) + result.append(self.generate_scene_and_questions()[2]) return result @@ -207,13 +207,17 @@ if __name__ == "__main__": grid_factory = GridFactory() - start_time = time.perf_counter() - samples = grid_factory.generate_samples(10000) - end_time = time.perf_counter() - print(f"{len(samples) / (end_time - start_time):.02f} samples per second") + # start_time = time.perf_counter() + # samples = grid_factory.generate_samples(10000) + # end_time = time.perf_counter() + # print(f"{len(samples) / (end_time - start_time):.02f} samples per second") - scene, questions = grid_factory.generate_scene_and_questions() + start_scene, scene, questions = grid_factory.generate_scene_and_questions() + print("-- Original scene -----------------------------") + grid_factory.print_scene(start_scene) + print("-- Transformed scene --------------------------") grid_factory.print_scene(scene) + print("-- Sequence -----------------------------------") print(questions) ###################################################################### diff --git a/tasks.py b/tasks.py index 2c2f914..24c13fe 100755 --- a/tasks.py +++ b/tasks.py @@ -1539,8 +1539,8 @@ class Grid(Task): nb_total = ar_mask.sum().item() nb_correct = ((correct == result).long() * ar_mask).sum().item() - logger(f"test_performance {nb_total=} {nb_correct=}") - logger(f"main_test_accuracy {nb_correct / nb_total}") + logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}") + logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}") ###################################################################### -- 2.39.5