From db698038df961f3e2f06e638a5f92cfc6fda39df Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 27 Jul 2024 23:55:57 +0200 Subject: [PATCH] Update. --- grids.py | 50 +++++++++++++++++++++++++----------- main.py | 67 ++++++++++++++++++++++++++++++++++--------------- mygpt.py | 39 +++++++++++++++------------- quiz_machine.py | 29 ++++++++++++++++++++- 4 files changed, 132 insertions(+), 53 deletions(-) diff --git a/grids.py b/grids.py index d41ec49..296c23a 100755 --- a/grids.py +++ b/grids.py @@ -722,7 +722,7 @@ class Grids(problem.Problem): return no, nq, nq_diag - def task_count(self, A, f_A, B, f_B): + def REMOVED_task_count(self, A, f_A, B, f_B): while True: error = False @@ -1022,7 +1022,7 @@ class Grids(problem.Problem): return dist * (1 - walls) # @torch.compile - def task_distance(self, A, f_A, B, f_B): + def REMOVED_task_distance(self, A, f_A, B, f_B): c = torch.randperm(len(self.colors) - 1)[:3] + 1 dist0 = torch.empty(self.height + 2, self.width + 2) dist1 = torch.empty(self.height + 2, self.width + 2) @@ -1085,7 +1085,7 @@ class Grids(problem.Problem): # if # @torch.compile - def task_puzzle(self, A, f_A, B, f_B): + def TOO_HARD_task_puzzle(self, A, f_A, B, f_B): S = 4 i0, j0 = (self.height - S) // 2, (self.width - S) // 2 c = torch.randperm(len(self.colors) - 1)[:4] + 1 @@ -1153,7 +1153,7 @@ class Grids(problem.Problem): if f_X[i + i0, j + j0] == c[d]: X[ii + i, jj + j] = c[d] - def task_islands(self, A, f_A, B, f_B): + def TOO_MESSY_task_islands(self, A, f_A, B, f_B): c = torch.randperm(len(self.colors) - 1)[:2] + 1 for X, f_X in [(A, f_A), (B, f_B)]: if not hasattr(self, "cache_islands") or len(self.cache_islands) == 0: @@ -1183,7 +1183,7 @@ class Grids(problem.Problem): X[i, j] = c[1] # @torch.compile - def task_stack(self, A, f_A, B, f_B): + def TOO_HARD_task_stack(self, A, f_A, B, f_B): N = 5 c = torch.randperm(len(self.colors) - 1)[:N] + 1 for X, f_X in [(A, f_A), (B, f_B)]: @@ -1228,7 +1228,7 @@ class Grids(problem.Problem): m = torch.tensor(m) return (torch.rand(m.size()) * m).long() - def task_matrices(self, A, f_A, B, f_B): + def TOO_HARD_task_matrices(self, A, f_A, B, f_B): N = 6 c = torch.randperm(len(self.colors) - 1)[:N] + 1 @@ -1244,7 +1244,7 @@ class Grids(problem.Problem): f_X[i, j + 5] = c[M2[i, j]] f_X[i + 5, j + 5] = c[P[i, j]] - def task_compute(self, A, f_A, B, f_B): + def TOO_HARD_task_compute(self, A, f_A, B, f_B): N = 6 c = torch.randperm(len(self.colors) - 1)[:N] + 1 for X, f_X in [(A, f_A), (B, f_B)]: @@ -1423,7 +1423,7 @@ class Grids(problem.Problem): if accept_full or (d * (X == 0)).max() == self.height * self.width: break - def task_addition(self, A, f_A, B, f_B): + def TOO_HARD_task_addition(self, A, f_A, B, f_B): c = torch.randperm(len(self.colors) - 1)[:4] + 1 for X, f_X in [(A, f_A), (B, f_B)]: N1 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item() @@ -1654,24 +1654,46 @@ if __name__ == "__main__": for t in [grids.task_science_tag]: print(t.__name__) - quizzes = grids.generate_w_quizzes_(nb, tasks=[t]) + w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t]) grids.save_quizzes_as_image( "/tmp", t.__name__ + ".png", - quizzes, - comments=[f"{t.__name__} #{k}" for k in range(quizzes.size(0))], + w_quizzes, + comments=[f"{t.__name__} #{k}" for k in range(w_quizzes.size(0))], ) - exit(0) + # exit(0) nb = 1000 - # for t in grids.all_tasks: - for t in [grids.task_path]: + for t in [ + grids.task_addition, + grids.task_bounce, + grids.task_compute, + grids.task_contact, + grids.task_corners, + grids.task_detect, + grids.task_fill, + grids.task_frame, + grids.task_grow, + grids.task_half_fill, + grids.task_islands, + grids.task_isometry, + grids.task_path, + grids.task_puzzle, + grids.task_replace_color, + grids.task_scale, + grids.task_stack, + grids.task_symbols, + grids.task_trajectory, + grids.task_translate, + ]: + # for t in [grids.task_path]: start_time = time.perf_counter() w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t]) delay = time.perf_counter() - start_time print(f"{t.__name__} {w_quizzes.size(0)/delay:02f} seq/s") + grids.save_quizzes_as_image("/tmp", t.__name__ + ".png", w_quizzes[:128]) exit(0) diff --git a/main.py b/main.py index bf617b5..fbc6f42 100755 --- a/main.py +++ b/main.py @@ -105,6 +105,8 @@ parser.add_argument("--c_quiz_validation_mode", type=str, default="predict") parser.add_argument("--dirty_debug", action="store_true", default=False) +parser.add_argument("--autoencoder_dim", type=int, default=-1) + ###################################################################### grids_tasks = ", ".join( @@ -449,18 +451,25 @@ def one_epoch(model, quiz_machine, local_device=main_device): ###################################################################### +lt_noisy = lambda s, logits: logits / args.temperature_hot +lt_clean = lambda s, logits: logits / args.temperature_cold + +c_quizzes_procedure_ = [ + (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), lt_noisy), + (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), lt_clean), + (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), lt_clean), +] + +c_quizzes_procedure = [ + (("A", "f_A", "B", "f_B"), (1, 1, 0, 0), lt_noisy), + (("A", "f_A", "B", "f_B"), (0, 0, 1, 1), lt_clean), +] + def save_additional_results(models, science_w_quizzes): for model in models: c_quizzes = quiz_machine.generate_c_quizzes( - 128, - model_for_generation=model, - temperature_hot=args.temperature_hot, - temperature_cold=args.temperature_cold, - ) - - c_quizzes = quiz_machine.problem.reconfigure( - c_quizzes, ("A", "f_A", "B", "f_B") + 128, model_for_generation=model, procedure=c_quizzes_procedure ) quiz_machine.problem.save_quizzes_as_image( @@ -541,10 +550,9 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 c_quizzes = quiz_machine.generate_c_quizzes( nb_to_generate_per_iteration, - model_for_generation=model_for_generation, - temperature_hot=args.temperature_hot, - temperature_cold=args.temperature_cold, - to_recycle=to_recycle, + model_for_generation=model, + procedure=c_quizzes_procedure, + # to_recycle=to_recycle, ) # We discard the trivial ones, according to a criterion @@ -662,7 +670,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 ###################################################################### -def train_auto_encoder(): +def train_autoencoder(): model = mygpt.MyGPT( vocabulary_size=vocabulary_size, dim_model=args.dim_model, @@ -672,10 +680,9 @@ def train_auto_encoder(): nb_blocks=args.nb_blocks, causal=False, dropout=args.dropout, + autoencoder_dim=args.autoencoder_dim, ).to(main_device) - model.make_auto_encoder(auto_encoder_dim=64) - test_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples) optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) @@ -713,7 +720,7 @@ def train_auto_encoder(): log_string(f"train_perplexity {n_epoch} model ae {train_perplexity}") - filename = f"auto_encoder.pth" + filename = f"autoencoder.pth" torch.save( model.state_dict(), os.path.join(args.result_dir, filename), @@ -722,7 +729,7 @@ def train_auto_encoder(): with torch.autograd.no_grad(): model.eval() - input = test_w_quizzes[:128, -l:] + input = test_w_quizzes[0 * 128 : 1 * 128, -l:] z_shape = model.encode(mygpt.BracketedSequence(input.to(main_device))) logits = model.decode(z_shape).x @@ -739,12 +746,32 @@ def train_auto_encoder(): q, ) - return model + input1 = test_w_quizzes[1 * 128 : 2 * 128, -l:] + input2 = test_w_quizzes[2 * 128 : 3 * 128, -l:] + z_shape1 = model.encode(mygpt.BracketedSequence(input1.to(main_device))) + z_shape2 = model.encode(mygpt.BracketedSequence(input2.to(main_device))) + z_shape = ((z_shape1[0] + z_shape2[0]) * 0.5, z_shape1[1]) + logits = model.decode(z_shape).x + + q = logits.argmax(dim=-1) + # q = q.reshape(q.size(0) // 2, 2, -1) + # input = input.reshape(input.size(0) // 2, 2, -1) + # q = torch.cat([input.to("cpu"), q.to("cpu")], dim=1).reshape(q.size(0), -1) + + q = q.reshape(q.size(0) // 4, -1) + + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, + f"culture_mix_ae_{n_epoch:04d}.png", + q, + ) + return model -# ae = train_auto_encoder() -# exit(0) +if args.autoencoder_dim > 0: + ae = train_autoencoder() + exit(0) ###################################################################### diff --git a/mygpt.py b/mygpt.py index b38cc99..fca2067 100755 --- a/mygpt.py +++ b/mygpt.py @@ -255,6 +255,7 @@ class MyGPT(nn.Module): nb_heads, nb_blocks, causal=False, + autoencoder_dim=-1, dropout=0.0, len_max=1e5, ): @@ -303,6 +304,26 @@ class MyGPT(nn.Module): nn.Linear(in_features=dim_model, out_features=vocabulary_size) ) + # ------------------------------------------------------- + if autoencoder_dim > 0: + self.encoder = nn.Sequential( + *( + trunk_blocks[: nb_blocks // 2] + + [EncoderHead(dim_model, autoencoder_dim)] + ) + ) + + self.decoder = nn.Sequential( + *( + [ + DecoderBottom(autoencoder_dim, dim_model), + AddPositionalEncoding(len_max), + ] + + trunk_blocks[nb_blocks // 2 :] + ) + ) + # ------------------------------------------------------- + with torch.no_grad(): for m in self.modules(): if isinstance(m, nn.Embedding): @@ -318,24 +339,6 @@ class MyGPT(nn.Module): bs = self.readout(bs) return bs - def make_auto_encoder(self, auto_encoder_dim): - self.encoder = nn.Sequential( - *( - trunk_blocks[: nb_blocks // 2] - + [EncoderHead(dim_model, auto_encoder_dim)] - ) - ) - - self.decoder = nn.Sequential( - *( - [ - DecoderBottom(auto_encoder_dim, dim_model), - AddPositionalEncoding(len_max), - ] - + trunk_blocks[nb_blocks // 2 :] - ) - ) - def encode(self, bs): bs = self.embedding(bs) z = self.encoder(bs) diff --git a/quiz_machine.py b/quiz_machine.py index 083b50e..d4b463b 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -433,7 +433,34 @@ class QuizMachine: ############################################################### - def generate_c_quizzes( + def generate_c_quizzes(self, nb, model_for_generation, procedure): + seq_logproba = torch.zeros(nb, device=self.device) + + c_quizzes = None + + for s, m, t in procedure: + if c_quizzes is None: + c_quizzes = self.problem.create_empty_quizzes(nb, s) + c_quizzes = c_quizzes.to(self.device) + elif s != pred_s: + c_quizzes = self.problem.reconfigure(c_quizzes, s) + pred_s = s + + self.autoregression( + model=model_for_generation, + input=c_quizzes, + ar_mask=self.make_ar_mask(c_quizzes, s, m), + seq_logproba=seq_logproba, + logit_transformer=t, + ) + + c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B")) + + return c_quizzes.to("cpu") + + ###################################################################### + + def generate_c_quizzes_orig( self, nb, model_for_generation, -- 2.39.5