X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=tasks.py;h=cdf8f9e48e583341ab68cdb30e90f5220095dd6e;hb=17c63771f2ca82ce39d8406e377ace2015fe69fc;hp=27173e1fabd9eaaf64cd79b5929313a52c9e6497;hpb=0b4e1b17c903e83463f5018bfbfb2fe2dda2636e;p=culture.git diff --git a/tasks.py b/tasks.py index 27173e1..cdf8f9e 100755 --- a/tasks.py +++ b/tasks.py @@ -22,6 +22,7 @@ def masked_inplace_autoregression( batch_size, input, ar_mask, + temperature, deterministic_synthesis, forbidden_tokens=None, logit_biases=None, @@ -44,17 +45,22 @@ def masked_inplace_autoregression( t = model.training model.eval() + sum_logits = 0 + for input, ar_mask in batches: - model.masked_inplace_autoregression( - input, - ar_mask, - deterministic_synthesis, - forbidden_tokens, - logit_biases, + sum_logits += model.masked_inplace_autoregression( + input=input, + ar_mask=ar_mask, + temperature=temperature, + deterministic_synthesis=deterministic_synthesis, + forbidden_tokens=forbidden_tokens, + forced_biases=logit_biases, ) model.train(t) + return sum_logits + ###################################################################### @@ -79,7 +85,7 @@ import world class World(Task): def save_image(self, input, result_dir, filename, logger): - img = world.sample2img(input.to("cpu"), self.height, self.width) + img = world.seq2img(input.to("cpu"), self.height, self.width) image_name = os.path.join(result_dir, filename) torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4) logger(f"wrote {image_name}") @@ -167,11 +173,12 @@ class World(Task): result = input.clone() * (1 - ar_mask) masked_inplace_autoregression( - model, - self.batch_size, - result, - ar_mask, - deterministic_synthesis, + model=model, + batch_size=self.batch_size, + input=result, + ar_mask=ar_mask, + temperature=1.0, + deterministic_synthesis=deterministic_synthesis, progress_bar_desc=None, device=self.device, ) @@ -205,11 +212,12 @@ class World(Task): result = input.clone() * (1 - ar_mask) masked_inplace_autoregression( - model, - self.batch_size, - result, - ar_mask, - deterministic_synthesis, + model=model, + batch_size=self.batch_size, + input=result, + ar_mask=ar_mask, + temperature=1.0, + deterministic_synthesis=deterministic_synthesis, progress_bar_desc=None, device=self.device, ) @@ -223,6 +231,14 @@ class World(Task): return main_test_accuracy + def renew_samples(self, nb, for_train=True): + input = self.train_input if for_train else self.test_input + nb = min(nb, input.size(0)) + input[:-nb] = input[nb:].clone() + input[-nb:] = world.generate_seq(nb, height=self.height, width=self.width).to( + self.device + ) + def store_new_quizzes(self, new_quizzes, for_train=True): if for_train: self.train_quizzes.append(new_quizzes) @@ -237,6 +253,7 @@ class World(Task): nb, model, other_models, + desired_average_logits=None, ): ############################################################### # Generate quizzes with model @@ -246,16 +263,32 @@ class World(Task): ) ar_mask = torch.full(quizzes.size(), 1, device=self.device) - masked_inplace_autoregression( - model, - self.batch_size, - quizzes, - ar_mask, + sum_logits = masked_inplace_autoregression( + model=model, + batch_size=self.batch_size, + input=quizzes, + ar_mask=ar_mask, + temperature=1.0, deterministic_synthesis=False, progress_bar_desc="creating quizzes", device=self.device, ) + average_logits = sum_logits / quizzes.numel() + + if desired_average_logits is not None: + temperature = average_logits / desired_average_logits + masked_inplace_autoregression( + model=model, + batch_size=self.batch_size, + input=quizzes, + ar_mask=ar_mask, + temperature=temperature, + deterministic_synthesis=False, + progress_bar_desc="creating quizzes", + device=self.device, + ) + ############################################################### # Create the reverse quizzes @@ -280,10 +313,11 @@ class World(Task): result = quizzes.clone() masked_inplace_autoregression( - m, - self.batch_size, - result, - ar_mask, + model=m, + batch_size=self.batch_size, + input=result, + ar_mask=ar_mask, + temperature=1.0, deterministic_synthesis=True, progress_bar_desc="solving quizzes", device=self.device, @@ -294,10 +328,11 @@ class World(Task): reverse_result = reverse_quizzes.clone() masked_inplace_autoregression( - m, - self.batch_size, - reverse_result, - ar_mask, + model=m, + batch_size=self.batch_size, + input=reverse_result, + ar_mask=ar_mask, + temperature=1.0, deterministic_synthesis=True, progress_bar_desc="solving reversed quizzes", device=self.device, @@ -316,4 +351,4 @@ class World(Task): for k in nb_correct: f.write(f"{k}\n") - return quizzes, nb_correct.sum(dim=0) + return quizzes, nb_correct.sum(dim=0), average_logits