X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=tasks.py;h=ee06c25e1923ce1a69880384526b1a1636064a5f;hb=9707563cb32ed2335dc4a6edddaa0ebe9cfd1243;hp=125432361bb4899eb3a64479339c130622f0c5ed;hpb=90eab15841632ef4f7bd88d2a7cbbb2426bf736a;p=culture.git

diff --git a/tasks.py b/tasks.py
index 1254323..ee06c25 100755
--- a/tasks.py
+++ b/tasks.py
@@ -22,6 +22,8 @@ def masked_inplace_autoregression(
     batch_size,
     input,
     ar_mask,
+    summed_logits,
+    temperature,
     deterministic_synthesis,
     forbidden_tokens=None,
     logit_biases=None,
@@ -46,11 +48,13 @@ def masked_inplace_autoregression(
 
         for input, ar_mask in batches:
             model.masked_inplace_autoregression(
-                input,
-                ar_mask,
-                deterministic_synthesis,
-                forbidden_tokens,
-                logit_biases,
+                input=input,
+                ar_mask=ar_mask,
+                summed_logits=summed_logits,
+                temperature=temperature,
+                deterministic_synthesis=deterministic_synthesis,
+                forbidden_tokens=forbidden_tokens,
+                forced_biases=logit_biases,
             )
 
         model.train(t)
@@ -79,7 +83,7 @@ import world
 
 class World(Task):
     def save_image(self, input, result_dir, filename, logger):
-        img = world.sample2img(input.to("cpu"), self.height, self.width)
+        img = world.seq2img(input.to("cpu"), self.height, self.width)
         image_name = os.path.join(result_dir, filename)
         torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
         logger(f"wrote {image_name}")
@@ -101,8 +105,8 @@ class World(Task):
         self.batch_size = batch_size
         self.device = device
 
-        self.height = 7
-        self.width = 9
+        self.height = 6
+        self.width = 8
 
         self.train_input = world.generate_seq(
             nb_train_samples, height=self.height, width=self.width
@@ -112,13 +116,6 @@ class World(Task):
             nb_test_samples, height=self.height, width=self.width
         ).to(device)
 
-        # print()
-        # for a in world.seq2str(self.train_input):
-        #     print(a)
-        # for a in world.seq2str(self.test_input):
-        #     print(a)
-        # exit(0)
-
         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
 
         self.train_quizzes = []
@@ -155,6 +152,9 @@ class World(Task):
             self.nb_batch_samples_world = input.size(0)
             self.nb_batch_samples_quizzes = 0
 
+        # Shuffle
+        input = input[torch.randperm(input.size(0))]
+
         if desc is None:
             desc = f"epoch-{split}"
         for batch in tqdm.tqdm(
@@ -174,11 +174,13 @@
             result = input.clone() * (1 - ar_mask)
 
             masked_inplace_autoregression(
-                model,
-                self.batch_size,
-                result,
-                ar_mask,
-                deterministic_synthesis,
+                model=model,
+                batch_size=self.batch_size,
+                input=result,
+                ar_mask=ar_mask,
+                summed_logits=None,
+                temperature=1.0,
+                deterministic_synthesis=deterministic_synthesis,
                 progress_bar_desc=None,
                 device=self.device,
             )
@@ -212,11 +214,13 @@
         result = input.clone() * (1 - ar_mask)
 
         masked_inplace_autoregression(
-            model,
-            self.batch_size,
-            result,
-            ar_mask,
-            deterministic_synthesis,
+            model=model,
+            batch_size=self.batch_size,
+            input=result,
+            ar_mask=ar_mask,
+            summed_logits=None,
+            temperature=1.0,
+            deterministic_synthesis=deterministic_synthesis,
             progress_bar_desc=None,
             device=self.device,
         )
@@ -230,6 +234,14 @@ class World(Task):
 
         return main_test_accuracy
 
+    def renew_samples(self, nb, for_train=True):
+        input = self.train_input if for_train else self.test_input
+        nb = min(nb, input.size(0))
+        input[:-nb] = input[nb:].clone()
+        input[-nb:] = world.generate_seq(nb, height=self.height, width=self.width).to(
+            self.device
+        )
+
     def store_new_quizzes(self, new_quizzes, for_train=True):
         if for_train:
             self.train_quizzes.append(new_quizzes)
@@ -244,6 +256,7 @@ class World(Task):
     def create_new_quizzes(
         self,
         n_epoch,
         result_dir,
         logger,
         nb,
         model,
         other_models,
+        desired_average_logits=None,
     ):
         ###############################################################
         # Generate quizzes with model
@@ -251,17 +264,48 @@
         quizzes = torch.empty(
             nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
         )
 
+        ar_mask = torch.full(quizzes.size(), 1, device=self.device)
+        summed_logits = torch.empty(nb, device=self.device)
 
-        masked_inplace_autoregression(
-            model,
-            self.batch_size,
-            quizzes,
-            ar_mask,
-            deterministic_synthesis=False,
-            progress_bar_desc="creating quizzes",
-            device=self.device,
-        )
+        temperature = 1
+        d_temperature = 1
+
+        while True:
+            summed_logits[...] = 0
+
+            masked_inplace_autoregression(
+                model=model,
+                batch_size=self.batch_size,
+                input=quizzes,
+                ar_mask=ar_mask,
+                summed_logits=summed_logits,
+                temperature=temperature,
+                deterministic_synthesis=False,
+                progress_bar_desc="creating quizzes",
+                device=self.device,
+            )
+
+            average_logits = summed_logits.mean()
+
+            logger(f"{average_logits=} {desired_average_logits=}")
+
+            if desired_average_logits is None:
+                break
+
+            # Oh man that's ugly
+            if average_logits < desired_average_logits * 1.1:
+                if d_temperature > 0:
+                    d_temperature *= -0.5
+                temperature += d_temperature
+            elif average_logits > desired_average_logits:
+                if d_temperature < 0:
+                    d_temperature *= -0.5
+                temperature += d_temperature
+            else:
+                break
+
+            logger(f"chaging temperature to {temperature}")
 
         ###############################################################
         # Create the reverse quizzes
@@ -287,10 +331,12 @@ class World(Task):
             result = quizzes.clone()
             masked_inplace_autoregression(
-                m,
-                self.batch_size,
-                result,
-                ar_mask,
+                model=m,
+                batch_size=self.batch_size,
+                input=result,
+                ar_mask=ar_mask,
+                summed_logits=None,
+                temperature=1.0,
                 deterministic_synthesis=True,
                 progress_bar_desc="solving quizzes",
                 device=self.device,
             )
@@ -301,10 +347,12 @@ class World(Task):
             reverse_result = reverse_quizzes.clone()
             masked_inplace_autoregression(
-                m,
-                self.batch_size,
-                reverse_result,
-                ar_mask,
+                model=m,
+                batch_size=self.batch_size,
+                input=reverse_result,
+                ar_mask=ar_mask,
+                summed_logits=None,
+                temperature=1.0,
                 deterministic_synthesis=True,
                 progress_bar_desc="solving reversed quizzes",
                 device=self.device,
             )
@@ -318,9 +366,9 @@ class World(Task):
 
         nb_correct = torch.cat(nb_correct, dim=0)
 
-        filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat")
-        with open(filename, "w") as f:
-            for k in nb_correct:
-                f.write(f"{k}\n")
+        # filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat")
+        # with open(filename, "w") as f:
+        #     for k in nb_correct:
+        #         f.write(f"{k}\n")
 
-        return quizzes, nb_correct.sum(dim=0)
+        return quizzes, nb_correct.sum(dim=0), summed_logits.mean()
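
The new renew_samples method refreshes part of the training (or test) pool in place: the oldest nb rows are dropped, the remaining rows are shifted toward the front, and nb freshly generated sequences are written at the end. The following is a small sketch of that rolling update under the assumption that the pool is a 2-D tensor of token sequences; renew_rows and the toy buffer are illustrative stand-ins, not the repository's API.

import torch


def renew_rows(buffer, fresh):
    # Drop the `fresh.size(0)` oldest rows, shift the rest toward the front,
    # and write the fresh rows at the end, all in place on `buffer`.
    nb = min(fresh.size(0), buffer.size(0))
    buffer[:-nb] = buffer[nb:].clone()  # clone() avoids the overlapping in-place copy
    buffer[-nb:] = fresh[:nb]


buf = torch.arange(10).unsqueeze(1).expand(10, 3).clone()  # row i is filled with i
renew_rows(buf, torch.full((4, 3), -1))
print(buf[:, 0])  # rows 4..9 shifted to the front, four fresh rows (-1) at the end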
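The other mechanism the patch introduces is the loop at the end of create_new_quizzes: quizzes are resampled while the sampling temperature is nudged until the mean summed log-probability of the generated batch falls between desired_average_logits * 1.1 and desired_average_logits (summed log-probabilities are negative, so the * 1.1 bound is the lower edge of the band). Whenever the search overshoots, the step direction is reversed and its size halved, a crude damped bisection. Below is a minimal, self-contained sketch of that control loop only, not the repository's code; generate and fake_generate are hypothetical stand-ins for the call to masked_inplace_autoregression and the summed_logits it fills in.

import torch


def find_temperature(generate, desired_average_logits):
    # Nudge the temperature until the mean summed log-probability of a
    # generated batch lands in [desired_average_logits * 1.1,
    # desired_average_logits] (values are negative, so * 1.1 is the lower edge).
    temperature = 1.0
    d_temperature = 1.0
    while True:
        summed_logits = generate(temperature)  # one value per generated sequence
        average_logits = summed_logits.mean()
        if average_logits < desired_average_logits * 1.1:
            if d_temperature > 0:  # overshot below the band: flip and halve the step
                d_temperature *= -0.5
            temperature += d_temperature
        elif average_logits > desired_average_logits:
            if d_temperature < 0:  # overshot above the band: flip and halve the step
                d_temperature *= -0.5
            temperature += d_temperature
        else:
            return temperature, summed_logits


if __name__ == "__main__":
    # Toy stand-in for the sampler: the mean summed log-probability drops
    # linearly as temperature grows, which is enough to exercise the search.
    fake_generate = lambda t: torch.full((256,), -50.0 * t)
    t, _ = find_temperature(fake_generate, desired_average_logits=-80.0)
    print(f"converged temperature = {t}")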