X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=50ded2c016aebf2c3e982d418a0343aa98d4ff72;hb=336130cc923761658029a0af9d5862d59405d47a;hp=27173e1fabd9eaaf64cd79b5929313a52c9e6497;hpb=0b4e1b17c903e83463f5018bfbfb2fe2dda2636e;p=culture.git diff --git a/tasks.py b/tasks.py index 27173e1..50ded2c 100755 --- a/tasks.py +++ b/tasks.py @@ -22,6 +22,8 @@ def masked_inplace_autoregression( batch_size, input, ar_mask, + seq_logproba, + temperature, deterministic_synthesis, forbidden_tokens=None, logit_biases=None, @@ -30,7 +32,11 @@ def masked_inplace_autoregression( ): assert input.size() == ar_mask.size() - batches = zip(input.split(batch_size), ar_mask.split(batch_size)) + batches = zip( + input.split(batch_size), + ar_mask.split(batch_size), + seq_logproba.split(batch_size), + ) if progress_bar_desc is not None: batches = tqdm.tqdm( @@ -44,13 +50,15 @@ def masked_inplace_autoregression( t = model.training model.eval() - for input, ar_mask in batches: + for input, ar_mask, seq_logproba in batches: model.masked_inplace_autoregression( - input, - ar_mask, - deterministic_synthesis, - forbidden_tokens, - logit_biases, + input=input, + ar_mask=ar_mask, + seq_logproba=seq_logproba, + temperature=temperature, + deterministic_synthesis=deterministic_synthesis, + forbidden_tokens=forbidden_tokens, + forced_biases=logit_biases, ) model.train(t) @@ -77,13 +85,16 @@ class Task: import world -class World(Task): +class QuizzMachine(Task): def save_image(self, input, result_dir, filename, logger): - img = world.sample2img(input.to("cpu"), self.height, self.width) + img = world.seq2img(input.to("cpu"), self.height, self.width) image_name = os.path.join(result_dir, filename) torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4) logger(f"wrote {image_name}") + def save_quizzes(self, input, result_dir, filename_prefix, logger): + self.save_image(input, result_dir, filename_prefix + ".png", logger) + def make_ar_mask(self, input): b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2 return b.long()[None, :].expand_as(input) @@ -104,49 +115,55 @@ class World(Task): self.height = 6 self.width = 8 - self.train_input = world.generate_seq( + self.train_w_quizzes = world.generate_seq( nb_train_samples, height=self.height, width=self.width ).to(device) - self.test_input = world.generate_seq( + self.test_w_quizzes = world.generate_seq( nb_test_samples, height=self.height, width=self.width ).to(device) - self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + self.nb_codes = max(self.train_w_quizzes.max(), self.test_w_quizzes.max()) + 1 - self.train_quizzes = [] - self.test_quizzes = [] + self.train_c_quizzes = [] + self.test_c_quizzes = [] if result_dir is not None: - self.save_image( - self.train_input[:72], result_dir, f"world_train.png", logger + self.save_quizzes( + self.train_w_quizzes[:72], result_dir, f"culture_w_quizzes", logger ) def batches(self, split="train", desc=None): assert split in {"train", "test"} if split == "train": - input = self.train_input - quizzes = self.train_quizzes + w_quizzes = self.train_w_quizzes + c_quizzes = self.train_c_quizzes else: - input = self.test_input - quizzes = self.test_quizzes + w_quizzes = self.test_w_quizzes + c_quizzes = self.test_c_quizzes - if len(quizzes) > 0: - quizzes = torch.cat(quizzes, dim=0) - if quizzes.size(0) > input.size(0) // 2: - i = torch.randperm(input.size(0))[: input.size(0) // 2] - quizzes = quizzes[i] + if len(c_quizzes) > 0: + c_quizzes = torch.cat(c_quizzes, dim=0) + if c_quizzes.size(0) > w_quizzes.size(0) // 2: + i = torch.randperm(w_quizzes.size(0))[: w_quizzes.size(0) // 2] + c_quizzes = c_quizzes[i] - i = torch.randperm(input.size(0))[: input.size(0) - quizzes.size(0)] - input = input[i] + i = torch.randperm(w_quizzes.size(0))[ + : w_quizzes.size(0) - c_quizzes.size(0) + ] + w_quizzes = w_quizzes[i] - self.nb_batch_samples_world = input.size(0) - self.nb_batch_samples_quizzes = quizzes.size(0) + self.nb_batch_w_quizzes = w_quizzes.size(0) + self.nb_batch_c_quizzes = c_quizzes.size(0) - input = torch.cat([input, quizzes], dim=0) + input = torch.cat([w_quizzes, c_quizzes], dim=0) else: - self.nb_batch_samples_world = input.size(0) - self.nb_batch_samples_quizzes = 0 + input = w_quizzes + self.nb_batch_w_quizzes = w_quizzes.size(0) + self.nb_batch_c_quizzes = 0 + + # Shuffle + input = input[torch.randperm(input.size(0))] if desc is None: desc = f"epoch-{split}" @@ -165,13 +182,16 @@ class World(Task): input = input[:nmax] ar_mask = self.make_ar_mask(input) result = input.clone() * (1 - ar_mask) + seq_logproba = torch.empty(input.size(0), device=self.device) masked_inplace_autoregression( - model, - self.batch_size, - result, - ar_mask, - deterministic_synthesis, + model=model, + batch_size=self.batch_size, + input=result, + ar_mask=ar_mask, + seq_logproba=seq_logproba, + temperature=1.0, + deterministic_synthesis=deterministic_synthesis, progress_bar_desc=None, device=self.device, ) @@ -183,13 +203,13 @@ class World(Task): return nb_total, nb_correct - train_nb_total, train_nb_correct = compute_accuracy(self.train_input) + train_nb_total, train_nb_correct = compute_accuracy(self.train_w_quizzes) logger( f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%" ) - test_nb_total, test_nb_correct = compute_accuracy(self.test_input, logger) + test_nb_total, test_nb_correct = compute_accuracy(self.test_w_quizzes, logger) logger( f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" @@ -200,36 +220,47 @@ class World(Task): ############################## - input = self.test_input[:96] + input = self.test_w_quizzes[:96] ar_mask = self.make_ar_mask(input) result = input.clone() * (1 - ar_mask) + seq_logproba = torch.empty(input.size(0), device=self.device) masked_inplace_autoregression( - model, - self.batch_size, - result, - ar_mask, - deterministic_synthesis, + model=model, + batch_size=self.batch_size, + input=result, + ar_mask=ar_mask, + seq_logproba=seq_logproba, + temperature=1.0, + deterministic_synthesis=deterministic_synthesis, progress_bar_desc=None, device=self.device, ) - self.save_image( + self.save_quizzes( result[:72], result_dir, - f"world_prediction_{n_epoch:04d}_{model.id:02d}.png", + f"culture_prediction_{n_epoch:04d}_{model.id:02d}", logger, ) return main_test_accuracy - def store_new_quizzes(self, new_quizzes, for_train=True): + def renew_w_quizzes(self, nb, for_train=True): + input = self.train_w_quizzes if for_train else self.test_w_quizzes + nb = min(nb, input.size(0)) + input[:-nb] = input[nb:].clone() + input[-nb:] = world.generate_seq(nb, height=self.height, width=self.width).to( + self.device + ) + + def store_c_quizzes(self, new_c_quizzes, for_train=True): if for_train: - self.train_quizzes.append(new_quizzes) + self.train_c_quizzes.append(new_c_quizzes) else: - self.test_quizzes.append(new_quizzes) + self.test_c_quizzes.append(new_c_quizzes) - def create_new_quizzes( + def create_c_quizzes( self, n_epoch, result_dir, @@ -237,38 +268,71 @@ class World(Task): nb, model, other_models, + min_ave_seq_logproba, ): ############################################################### # Generate quizzes with model - quizzes = torch.empty( + c_quizzes = torch.empty( nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64 ) - ar_mask = torch.full(quizzes.size(), 1, device=self.device) - masked_inplace_autoregression( - model, - self.batch_size, - quizzes, - ar_mask, - deterministic_synthesis=False, - progress_bar_desc="creating quizzes", - device=self.device, - ) + ar_mask = torch.full(c_quizzes.size(), 1, device=self.device) + seq_logproba = torch.empty(ar_mask.size(0), device=self.device) + + temperature = 1 + d_temperature = 1 + + while True: + seq_logproba[...] = 0 + + masked_inplace_autoregression( + model=model, + batch_size=self.batch_size, + input=c_quizzes, + ar_mask=ar_mask, + seq_logproba=seq_logproba, + temperature=temperature, + deterministic_synthesis=False, + progress_bar_desc="sampling c_quizzes", + device=self.device, + ) + + ave_seq_logproba = seq_logproba.mean() + + logger(f"{ave_seq_logproba=} {min_ave_seq_logproba=}") + + if min_ave_seq_logproba is None: + break + + # Oh man that's ugly + if ave_seq_logproba < min_ave_seq_logproba * 1.1: + if d_temperature > 0: + d_temperature *= -1 / 3 + temperature += d_temperature + elif ave_seq_logproba > min_ave_seq_logproba: + if d_temperature < 0: + d_temperature *= -1 / 3 + temperature += d_temperature + else: + break + + logger(f"chaging temperature to {temperature}") ############################################################### # Create the reverse quizzes l = self.height * self.width - direction = quizzes[:, l : l + 1] + direction = c_quizzes[:, l : l + 1] direction = world.token_forward * ( direction == world.token_backward ) + world.token_backward * (direction == world.token_forward) - reverse_quizzes = torch.cat( - [quizzes[:, l + 1 :], direction, quizzes[:, :l]], dim=1 + reverse_c_quizzes = torch.cat( + [c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1 ) - ar_mask = self.make_ar_mask(quizzes) + ar_mask = self.make_ar_mask(c_quizzes) + seq_logproba = torch.empty(ar_mask.size(0), device=self.device) ############################################################### # Check how many of the other models can solve them in both @@ -277,43 +341,42 @@ class World(Task): nb_correct = [] for m in other_models: - result = quizzes.clone() + result = c_quizzes.clone() masked_inplace_autoregression( - m, - self.batch_size, - result, - ar_mask, + model=m, + batch_size=self.batch_size, + input=result, + ar_mask=ar_mask, + seq_logproba=seq_logproba, + temperature=1.0, deterministic_synthesis=True, - progress_bar_desc="solving quizzes", + progress_bar_desc="solving c_quizzes", device=self.device, ) - correct = (quizzes == result).long().min(dim=-1).values + correct = (c_quizzes == result).long().min(dim=-1).values - reverse_result = reverse_quizzes.clone() + reverse_result = reverse_c_quizzes.clone() masked_inplace_autoregression( - m, - self.batch_size, - reverse_result, - ar_mask, + model=m, + batch_size=self.batch_size, + input=reverse_result, + ar_mask=ar_mask, + seq_logproba=seq_logproba, + temperature=1.0, deterministic_synthesis=True, - progress_bar_desc="solving reversed quizzes", + progress_bar_desc="solving reversed c_quizzes", device=self.device, ) reverse_correct = ( - (reverse_quizzes == reverse_result).long().min(dim=-1).values + (reverse_c_quizzes == reverse_result).long().min(dim=-1).values ) nb_correct.append((correct * reverse_correct)[None, :]) - nb_correct = torch.cat(nb_correct, dim=0) - - filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat") - with open(filename, "w") as f: - for k in nb_correct: - f.write(f"{k}\n") + nb_correct = torch.cat(nb_correct, dim=0).sum(dim=0) - return quizzes, nb_correct.sum(dim=0) + return c_quizzes, nb_correct, seq_logproba.mean()