From 17c63771f2ca82ce39d8406e377ace2015fe69fc Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 23 Jun 2024 23:01:56 +0200 Subject: [PATCH] Update. --- main.py | 34 ++++++++++++++++++---- mygpt.py | 7 +++++ tasks.py | 89 ++++++++++++++++++++++++++++++++++++-------------------- world.py | 84 ++++++++++++++++++++++++++++++++-------------------- 4 files changed, 146 insertions(+), 68 deletions(-) diff --git a/main.py b/main.py index 683c07d..a021a71 100755 --- a/main.py +++ b/main.py @@ -335,23 +335,30 @@ def create_quizzes( task, nb_for_train=1000, nb_for_test=100, + desired_average_logits=None, ): kept = [] + nb_generated_tokens, sum_logits = 0, 0 while sum([x.size(0) for x in kept]) < nb_for_train + nb_for_test: - new_quizzes, nb_correct = task.create_new_quizzes( + nb_to_generate = 4 * (nb_for_train + nb_for_test) + new_quizzes, nb_correct, average_logits = task.create_new_quizzes( n_epoch=n_epoch, result_dir=args.result_dir, logger=log_string, - nb=4 * (nb_for_train + nb_for_test), + nb=nb_to_generate, model=model, other_models=other_models, + desired_average_logits=desired_average_logits, ) - print(nb_correct) + nb_generated_tokens += new_quizzes.numel() + sum_logits += average_logits * new_quizzes.numel() to_keep = new_quizzes[nb_correct == len(other_models) - 1] - log_string(f"keep {to_keep.size(0)} quizzes") + log_string( + f"keep {to_keep.size(0)}/{new_quizzes.size(0)} quizzes ({to_keep.size(0)*100/new_quizzes.size(0):.02f}%)" + ) kept.append(to_keep) new_quizzes = torch.cat(kept, dim=0)[: nb_for_train + nb_for_test] @@ -366,6 +373,8 @@ def create_quizzes( log_string, ) + return sum_logits / nb_generated_tokens + ###################################################################### @@ -403,6 +412,8 @@ if args.check: nb_new_quizzes_for_train = 10 nb_new_quizzes_for_test = 10 +desired_average_logits = None + for n_epoch in range(args.nb_epochs): a = [(model.id, float(model.main_test_accuracy)) for model in models] a.sort(key=lambda p: p[0]) @@ -428,18 +439,31 @@ for n_epoch in range(args.nb_epochs): # test it run_tests(model, task, deterministic_synthesis=False) + log_string( + f"test_set_composition world {task.nb_batch_samples_world} quizzes {task.nb_batch_samples_quizzes}" + ) + if min([m.main_test_accuracy for m in models]) >= accuracy_to_make_quizzes: other_models = models.copy() other_models.remove(model) - create_quizzes( + average_logits = create_quizzes( model, other_models, task, nb_for_train=nb_new_quizzes_for_train, nb_for_test=nb_new_quizzes_for_test, + desired_average_logits=desired_average_logits, ) + # We keep the first average logits as a reference + if desired_average_logits is None: + desired_average_logits = average_logits + else: + log_string( + f"desired_average_logits {desired_average_logits} average_logits {average_logits}" + ) + # We update everyone for model in models: run_tests(model, task, deterministic_synthesis=False) diff --git a/mygpt.py b/mygpt.py index 131c822..3bb3519 100755 --- a/mygpt.py +++ b/mygpt.py @@ -279,10 +279,12 @@ class MyGPT(nn.Module): self, input, ar_mask, + temperature=1.0, deterministic_synthesis=False, forbidden_tokens=None, forced_biases=None, ): + sum_logits = 0 to_generate = (ar_mask.sum(0) > 0).nonzero() if to_generate.min() > 0: self( @@ -300,8 +302,13 @@ class MyGPT(nn.Module): else: dist = torch.distributions.categorical.Categorical(logits=logits) t_next = dist.sample() + sum_logits += logits.log_softmax(dim=-1)[ + torch.arange(t_next.size(0)), t_next + ] input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s] + return sum_logits + def record_attention(self, v=True): for m in self.modules(): if isinstance(m, QKVAttention): diff --git a/tasks.py b/tasks.py index 2c88333..cdf8f9e 100755 --- a/tasks.py +++ b/tasks.py @@ -22,6 +22,7 @@ def masked_inplace_autoregression( batch_size, input, ar_mask, + temperature, deterministic_synthesis, forbidden_tokens=None, logit_biases=None, @@ -44,17 +45,22 @@ def masked_inplace_autoregression( t = model.training model.eval() + sum_logits = 0 + for input, ar_mask in batches: - model.masked_inplace_autoregression( - input, - ar_mask, - deterministic_synthesis, - forbidden_tokens, - logit_biases, + sum_logits += model.masked_inplace_autoregression( + input=input, + ar_mask=ar_mask, + temperature=temperature, + deterministic_synthesis=deterministic_synthesis, + forbidden_tokens=forbidden_tokens, + forced_biases=logit_biases, ) model.train(t) + return sum_logits + ###################################################################### @@ -79,7 +85,7 @@ import world class World(Task): def save_image(self, input, result_dir, filename, logger): - img = world.sample2img(input.to("cpu"), self.height, self.width) + img = world.seq2img(input.to("cpu"), self.height, self.width) image_name = os.path.join(result_dir, filename) torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4) logger(f"wrote {image_name}") @@ -167,11 +173,12 @@ class World(Task): result = input.clone() * (1 - ar_mask) masked_inplace_autoregression( - model, - self.batch_size, - result, - ar_mask, - deterministic_synthesis, + model=model, + batch_size=self.batch_size, + input=result, + ar_mask=ar_mask, + temperature=1.0, + deterministic_synthesis=deterministic_synthesis, progress_bar_desc=None, device=self.device, ) @@ -205,11 +212,12 @@ class World(Task): result = input.clone() * (1 - ar_mask) masked_inplace_autoregression( - model, - self.batch_size, - result, - ar_mask, - deterministic_synthesis, + model=model, + batch_size=self.batch_size, + input=result, + ar_mask=ar_mask, + temperature=1.0, + deterministic_synthesis=deterministic_synthesis, progress_bar_desc=None, device=self.device, ) @@ -245,6 +253,7 @@ class World(Task): nb, model, other_models, + desired_average_logits=None, ): ############################################################### # Generate quizzes with model @@ -254,16 +263,32 @@ class World(Task): ) ar_mask = torch.full(quizzes.size(), 1, device=self.device) - masked_inplace_autoregression( - model, - self.batch_size, - quizzes, - ar_mask, + sum_logits = masked_inplace_autoregression( + model=model, + batch_size=self.batch_size, + input=quizzes, + ar_mask=ar_mask, + temperature=1.0, deterministic_synthesis=False, progress_bar_desc="creating quizzes", device=self.device, ) + average_logits = sum_logits / quizzes.numel() + + if desired_average_logits is not None: + temperature = average_logits / desired_average_logits + masked_inplace_autoregression( + model=model, + batch_size=self.batch_size, + input=quizzes, + ar_mask=ar_mask, + temperature=temperature, + deterministic_synthesis=False, + progress_bar_desc="creating quizzes", + device=self.device, + ) + ############################################################### # Create the reverse quizzes @@ -288,10 +313,11 @@ class World(Task): result = quizzes.clone() masked_inplace_autoregression( - m, - self.batch_size, - result, - ar_mask, + model=m, + batch_size=self.batch_size, + input=result, + ar_mask=ar_mask, + temperature=1.0, deterministic_synthesis=True, progress_bar_desc="solving quizzes", device=self.device, @@ -302,10 +328,11 @@ class World(Task): reverse_result = reverse_quizzes.clone() masked_inplace_autoregression( - m, - self.batch_size, - reverse_result, - ar_mask, + model=m, + batch_size=self.batch_size, + input=reverse_result, + ar_mask=ar_mask, + temperature=1.0, deterministic_synthesis=True, progress_bar_desc="solving reversed quizzes", device=self.device, @@ -324,4 +351,4 @@ class World(Task): for k in nb_correct: f.write(f"{k}\n") - return quizzes, nb_correct.sum(dim=0) + return quizzes, nb_correct.sum(dim=0), average_logits diff --git a/world.py b/world.py index 839f4ff..36aa1e9 100755 --- a/world.py +++ b/world.py @@ -41,16 +41,15 @@ token2char = "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) def generate_seq( - nb, - height, - width, - nb_birds=3, - nb_iterations=2, + nb, height, width, nb_birds=3, nb_iterations=2, return_iterations=False ): pairs = [] + kept_iterations = [] for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"): while True: + iterations = [] + f_start = torch.zeros(height, width, dtype=torch.int64) i, j, vi, vj = ( @@ -90,6 +89,7 @@ def generate_seq( f_end = f_start.clone() for l in range(nb_iterations): + iterations.append(f_end.clone()) f_end[...] = 0 nb_collisions = 0 for n in range(nb_birds): @@ -125,9 +125,12 @@ def generate_seq( f_end[i[n] - vi[n], j[n]] = c f_end[i[n], j[n] - vj[n]] = c + iterations.append(f_end.clone()) + if nb_collisions == 0: break + kept_iterations.append(iterations) pairs.append((f_start, f_end)) result = [] @@ -147,7 +150,11 @@ def generate_seq( )[None, :] ) - return torch.cat(result, dim=0) + if return_iterations: + # iterations = torch.cat([ torch.cat([ x[None, None] for x in l], dim = 1) for l in kept_iterations ], dim=0) + return torch.cat(result, dim=0), kept_iterations + else: + return torch.cat(result, dim=0) ###################################################################### @@ -219,32 +226,33 @@ def generate_seq_old( return torch.cat(result, dim=0) -def sample2img(seq, height, width, upscale=15): - f_first = seq[:, : height * width].reshape(-1, height, width) - f_second = seq[:, height * width + 1 :].reshape(-1, height, width) - direction = seq[:, height * width] +def frame2img(x, height, width, upscale=15): + x = x.reshape(-1, height, width) + m = torch.logical_and(x >= 0, x < first_bird_token + nb_bird_tokens).long() + x = colors[x * m].permute(0, 3, 1, 2) + s = x.shape + x = x[:, :, :, None, :, None].expand(-1, -1, -1, upscale, -1, upscale) + x = x.reshape(s[0], s[1], s[2] * upscale, s[3] * upscale) - def mosaic(x, upscale): - x = x.reshape(-1, height, width) - m = torch.logical_and(x >= 0, x < first_bird_token + nb_bird_tokens).long() - x = colors[x * m].permute(0, 3, 1, 2) - s = x.shape - x = x[:, :, :, None, :, None].expand(-1, -1, -1, upscale, -1, upscale) - x = x.reshape(s[0], s[1], s[2] * upscale, s[3] * upscale) + x[:, :, :, torch.arange(0, x.size(3), upscale)] = 0 + x[:, :, torch.arange(0, x.size(2), upscale), :] = 0 + x = x[:, :, 1:, 1:] - x[:, :, :, torch.arange(0, x.size(3), upscale)] = 0 - x[:, :, torch.arange(0, x.size(2), upscale), :] = 0 - x = x[:, :, 1:, 1:] + for n in range(m.size(0)): + for i in range(m.size(1)): + for j in range(m.size(2)): + if m[n, i, j] == 0: + for k in range(2, upscale - 2): + x[n, :, i * upscale + k, j * upscale + k] = 0 + x[n, :, i * upscale + upscale - 1 - k, j * upscale + k] = 0 - for n in range(m.size(0)): - for i in range(m.size(1)): - for j in range(m.size(2)): - if m[n, i, j] == 0: - for k in range(2, upscale - 2): - x[n, :, i * upscale + k, j * upscale + k] = 0 - x[n, :, i * upscale + upscale - 1 - k, j * upscale + k] = 0 + return x - return x + +def seq2img(seq, height, width, upscale=15): + f_first = seq[:, : height * width].reshape(-1, height, width) + f_second = seq[:, height * width + 1 :].reshape(-1, height, width) + direction = seq[:, height * width] direction_symbol = torch.full((direction.size(0), height * upscale - 1, upscale), 0) direction_symbol = colors[direction_symbol].permute(0, 3, 1, 2) @@ -278,11 +286,11 @@ def sample2img(seq, height, width, upscale=15): return torch.cat( [ - mosaic(f_first, upscale), + frame2img(f_first, height, width, upscale), separator, direction_symbol, separator, - mosaic(f_second, upscale), + frame2img(f_second, height, width, upscale), ], dim=3, ) @@ -302,16 +310,28 @@ if __name__ == "__main__": height, width = 6, 8 start_time = time.perf_counter() - seq = generate_seq(nb=90, height=height, width=width) + seq, it = generate_seq( + nb=64, height=height, width=width, nb_iterations=100, return_iterations=True + ) delay = time.perf_counter() - start_time print(f"{seq.size(0)/delay:02f} samples/s") print(seq2str(seq[:4])) + for t in range(len(it[0])): + img = torch.cat([frame2img(f[t], height, width) for f in it], dim=0) + torchvision.utils.save_image( + img.float() / 255.0, + f"/tmp/frame_{t:03d}.png", + nrow=8, + padding=6, + pad_value=0, + ) + # m = (torch.rand(seq.size()) < 0.05).long() # seq = (1 - m) * seq + m * 23 - img = sample2img(seq, height, width) + img = seq2img(seq, height, width) print(img.size()) torchvision.utils.save_image( -- 2.20.1