From 6d1dcbd012266474c9a2118302cc365d458752fb Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 3 Aug 2024 17:39:09 +0200 Subject: [PATCH] Update. --- grids.py | 80 ++++++++++++++++++++++++++++++++++++++++++-------------- main.py | 40 +++++++++++++++++++++++++--- 2 files changed, 98 insertions(+), 22 deletions(-) diff --git a/grids.py b/grids.py index 05c3057..f195144 100755 --- a/grids.py +++ b/grids.py @@ -257,6 +257,9 @@ class Grids(problem.Problem): self.token_B = self.token_f_A + 1 self.token_f_B = self.token_B + 1 + self.nb_rec_max = 5 + self.rfree = torch.tensor([]) + self.l2tok = { "A": self.token_A, "f_A": self.token_f_A, @@ -574,15 +577,39 @@ class Grids(problem.Problem): ###################################################################### + def contact_matrices(self, rn, ri, rj, rz): + return ( + ( + ( + ( + (ri[:, :, None, 0] == ri[:, None, :, 1] + 1) + | (ri[:, :, None, 1] + 1 == ri[:, None, :, 0]) + ) + & (rj[:, :, None, 0] <= rj[:, None, :, 1]) + & (rj[:, :, None, 1] >= rj[:, None, :, 0]) + ) + | ( + ( + (rj[:, :, None, 0] == rj[:, None, :, 1] + 1) + | (rj[:, :, None, 1] + 1 == rj[:, None, :, 0]) + ) + & (ri[:, :, None, 0] <= ri[:, None, :, 1]) + & (ri[:, :, None, 1] >= ri[:, None, :, 0]) + ) + ) + # & (rz[:, :, None] == rz[:, None, :]) + & (n[None, :, None] < rn[:, None, None]) + & (n[None, None, :] < n[None, :, None]) + ) + def sample_rworld_states(self, N=1000): - nb_rec_max = 5 while True: - rn = torch.randint(nb_rec_max - 1, (N,)) + 2 - ri = torch.randint(self.height, (N, nb_rec_max, 2)).sort(dim=2).values - rj = torch.randint(self.width, (N, nb_rec_max, 2)).sort(dim=2).values - rz = torch.randint(2, (N, nb_rec_max)) - rc = torch.randint(self.nb_colors - 1, (N, nb_rec_max)) + 1 - n = torch.arange(nb_rec_max) + rn = torch.randint(self.nb_rec_max - 1, (N,)) + 2 + ri = torch.randint(self.height, (N, self.nb_rec_max, 2)).sort(dim=2).values + rj = torch.randint(self.width, (N, self.nb_rec_max, 2)).sort(dim=2).values + rz = torch.randint(2, (N, self.nb_rec_max)) + rc = torch.randint(self.nb_colors - 1, (N, self.nb_rec_max)) + 1 + n = torch.arange(self.nb_rec_max) nb_collisions = ( ( (ri[:, :, None, 0] <= ri[:, None, :, 1]) @@ -607,17 +634,34 @@ class Grids(problem.Problem): self.rj = rj[no_collision] self.rz = rz[no_collision] self.rc = rc[no_collision] + + nb_contact = ( + contact_matrices(rn, ri, rj, rz).long().flatten(1).sum(dim=1) + ) + + self.rcontact = nb_contact > 0 + self.rfree = torch.full((self.rn.size(0),), True) + break - def task_rworld_change_color(self, A, f_A, B, f_B): - nb_rec = 3 - c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1 + def get_recworld_state(self): + if not self.rfree.any(): + self.sample_rworld_states() + k = torch.arange(self.rn.size(0))[self.rfree] + k = k[torch.randint(k.size(0), (1,))].item() + self.rfree[k] = False + return self.rn[k], self.ri[k], self.rj[k], self.rz[k], self.rc[k] + + def draw_state(self, X, rn, ri, rj, rz, rc): + for n in sorted(list(range(rn)), key=lambda n: rz[n].item()): + X[ri[n, 0] : ri[n, 1] + 1, rj[n, 0] : rj[n, 1] + 1] = rc[n] + + def task_recworld_immobile(self, A, f_A, B, f_B): for X, f_X in [(A, f_A), (B, f_B)]: - r = self.rec_coo(nb_rec, prevent_overlap=True) - for n in range(nb_rec): - i1, j1, i2, j2 = r[n] - X[i1:i2, j1:j2] = c[n] - f_X[i1:i2, j1:j2] = c[n if n > 0 else -1] + rn, ri, rj, rz, rc = self.get_recworld_state() + self.draw_state(X, rn, ri, rj, rz, rc) + ri += 1 + self.draw_state(f_X, rn, ri, rj, rz, rc) ###################################################################### @@ -1703,8 +1747,6 @@ if __name__ == "__main__": # grids = Grids(max_nb_cached_chunks=5, chunk_size=100, nb_threads=4) grids = Grids() - grids.sample_rworld_states() - exit(0) # nb = 5 # quizzes = grids.generate_w_quizzes_(nb, tasks=[grids.task_fill]) @@ -1746,7 +1788,7 @@ if __name__ == "__main__": # for t in grids.all_tasks: - for t in [grids.task_science_tag]: + for t in [grids.task_recworld_immobile]: print(t.__name__) w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t]) grids.save_quizzes_as_image( @@ -1756,7 +1798,7 @@ if __name__ == "__main__": comments=[f"{t.__name__} #{k}" for k in range(w_quizzes.size(0))], ) - # exit(0) + exit(0) nb = 1000 diff --git a/main.py b/main.py index 36b58e2..9a8bd43 100755 --- a/main.py +++ b/main.py @@ -478,29 +478,63 @@ c_quizzes_procedure = [ def save_additional_results(models, science_w_quizzes): + # Save generated quizzes with the successive steps + for model in models: recorder = [] c_quizzes = quiz_machine.generate_c_quizzes( - 32, + 64, model_for_generation=model, procedure=c_quizzes_procedure, recorder=recorder, ) + ## + + probas = 0 + + for a in range(args.nb_averaging_rounds): + # This is nb_quizzes x nb_models + + seq_logproba = quiz_machine.models_logprobas( + models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) + quiz_machine.models_logprobas( + models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) + + probas += seq_logproba.exp() + + probas /= args.nb_averaging_rounds + + comments = [] + + for l in seq_logproba: + comments.append("proba " + " ".join([f"{x.exp().item():.02f}" for x in l])) + + ## + c_quizzes = torch.cat([c[:, None, :] for c, _, in recorder], dim=1) predicted_parts = torch.cat([t[:, None, :] for _, t in recorder], dim=1) - nrow = c_quizzes.size(1) + nb_steps = c_quizzes.size(1) c_quizzes = c_quizzes.reshape(-1, c_quizzes.size(-1)) predicted_parts = predicted_parts.reshape(-1, predicted_parts.size(-1)) + # We have comments only for the final quiz, not the successive + # steps, so we have to add nb_steps-1 empty comments + + steps_comments = [] + for c in comments: + steps_comments += [""] * (nb_steps - 1) + [c] + filename = f"non_validated_{n_epoch:04d}_{model.id:02d}.png" quiz_machine.problem.save_quizzes_as_image( args.result_dir, filename, quizzes=c_quizzes, predicted_parts=predicted_parts, - nrow=nrow, + comments=steps_comments, + nrow=nb_steps * 2, # two quiz per row ) log_string(f"wrote {filename}") -- 2.20.1