class GridFactory:
def __init__(
self,
- size=4,
+ size=6,
max_nb_items=4,
max_nb_transformations=3,
nb_questions=4,
def generate_scene_and_questions(self):
while True:
while True:
- scene = self.generate_scene()
- true = self.all_properties(scene)
+ start_scene = self.generate_scene()
+ true = self.all_properties(start_scene)
if len(true) >= self.nb_questions:
break
- start = self.grid_positions(scene)
+ start = self.grid_positions(start_scene)
- scene, transformations = self.random_transformations(scene)
+ scene, transformations = self.random_transformations(start_scene)
# transformations=[]
+ questions
)
- return scene, result
+ return start_scene, scene, result
def generate_samples(self, nb, progress_bar=None):
result = []
r = progress_bar(r)
for _ in r:
- result.append(self.generate_scene_and_questions()[1])
+ result.append(self.generate_scene_and_questions()[2])
return result
grid_factory = GridFactory()
- start_time = time.perf_counter()
- samples = grid_factory.generate_samples(10000)
- end_time = time.perf_counter()
- print(f"{len(samples) / (end_time - start_time):.02f} samples per second")
+ # start_time = time.perf_counter()
+ # samples = grid_factory.generate_samples(10000)
+ # end_time = time.perf_counter()
+ # print(f"{len(samples) / (end_time - start_time):.02f} samples per second")
- scene, questions = grid_factory.generate_scene_and_questions()
+ start_scene, scene, questions = grid_factory.generate_scene_and_questions()
+ print("-- Original scene -----------------------------")
+ grid_factory.print_scene(start_scene)
+ print("-- Transformed scene --------------------------")
grid_factory.print_scene(scene)
+ print("-- Sequence -----------------------------------")
print(questions)
######################################################################
nb_total = ar_mask.sum().item()
nb_correct = ((correct == result).long() * ar_mask).sum().item()
- logger(f"test_performance {nb_total=} {nb_correct=}")
- logger(f"main_test_accuracy {nb_correct / nb_total}")
+ logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
+ logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
######################################################################