X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=ee06c25e1923ce1a69880384526b1a1636064a5f;hb=9707563cb32ed2335dc4a6edddaa0ebe9cfd1243;hp=5edb472024342565801fee1f7de48f55bb17756a;hpb=702e672dcf9ebcfad11ae4034e64117f2c67ead5;p=culture.git diff --git a/tasks.py b/tasks.py index 5edb472..ee06c25 100755 --- a/tasks.py +++ b/tasks.py @@ -22,6 +22,7 @@ def masked_inplace_autoregression( batch_size, input, ar_mask, + summed_logits, temperature, deterministic_synthesis, forbidden_tokens=None, @@ -45,12 +46,11 @@ def masked_inplace_autoregression( t = model.training model.eval() - sum_logits = 0 - for input, ar_mask in batches: - sum_logits += model.masked_inplace_autoregression( + model.masked_inplace_autoregression( input=input, ar_mask=ar_mask, + summed_logits=summed_logits, temperature=temperature, deterministic_synthesis=deterministic_synthesis, forbidden_tokens=forbidden_tokens, @@ -59,8 +59,6 @@ def masked_inplace_autoregression( model.train(t) - return sum_logits - ###################################################################### @@ -154,6 +152,9 @@ class World(Task): self.nb_batch_samples_world = input.size(0) self.nb_batch_samples_quizzes = 0 + # Shuffle + input = input[torch.randperm(input.size(0))] + if desc is None: desc = f"epoch-{split}" for batch in tqdm.tqdm( @@ -177,6 +178,7 @@ class World(Task): batch_size=self.batch_size, input=result, ar_mask=ar_mask, + summed_logits=None, temperature=1.0, deterministic_synthesis=deterministic_synthesis, progress_bar_desc=None, @@ -216,6 +218,7 @@ class World(Task): batch_size=self.batch_size, input=result, ar_mask=ar_mask, + summed_logits=None, temperature=1.0, deterministic_synthesis=deterministic_synthesis, progress_bar_desc=None, @@ -261,41 +264,49 @@ class World(Task): quizzes = torch.empty( nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64 ) - ar_mask = torch.full(quizzes.size(), 1, device=self.device) - sum_logits = masked_inplace_autoregression( - model=model, - batch_size=self.batch_size, - input=quizzes, - ar_mask=ar_mask, - temperature=1.0, - deterministic_synthesis=False, - progress_bar_desc="creating quizzes", - device=self.device, - ) - - # Should not be necessary though, the autoregression is done - # in eval mode - sum_logits = sum_logits.detach() + ar_mask = torch.full(quizzes.size(), 1, device=self.device) + summed_logits = torch.empty(nb, device=self.device) - average_logits = sum_logits / quizzes.numel() + temperature = 1 + d_temperature = 1 - # It's a bit brutal to do it twice, we should probably have a - # moving average and apply it right away + while True: + summed_logits[...] = 0 - if desired_average_logits is not None: - temperature = average_logits / desired_average_logits masked_inplace_autoregression( model=model, batch_size=self.batch_size, input=quizzes, ar_mask=ar_mask, + summed_logits=summed_logits, temperature=temperature, deterministic_synthesis=False, progress_bar_desc="creating quizzes", device=self.device, ) + average_logits = summed_logits.mean() + + logger(f"{average_logits=} {desired_average_logits=}") + + if desired_average_logits is None: + break + + # Oh man that's ugly + if average_logits < desired_average_logits * 1.1: + if d_temperature > 0: + d_temperature *= -0.5 + temperature += d_temperature + elif average_logits > desired_average_logits: + if d_temperature < 0: + d_temperature *= -0.5 + temperature += d_temperature + else: + break + + logger(f"chaging temperature to {temperature}") + ############################################################### # Create the reverse quizzes @@ -324,6 +335,7 @@ class World(Task): batch_size=self.batch_size, input=result, ar_mask=ar_mask, + summed_logits=None, temperature=1.0, deterministic_synthesis=True, progress_bar_desc="solving quizzes", @@ -339,6 +351,7 @@ class World(Task): batch_size=self.batch_size, input=reverse_result, ar_mask=ar_mask, + summed_logits=None, temperature=1.0, deterministic_synthesis=True, progress_bar_desc="solving reversed quizzes", @@ -353,9 +366,9 @@ class World(Task): nb_correct = torch.cat(nb_correct, dim=0) - filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat") - with open(filename, "w") as f: - for k in nb_correct: - f.write(f"{k}\n") + # filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat") + # with open(filename, "w") as f: + # for k in nb_correct: + # f.write(f"{k}\n") - return quizzes, nb_correct.sum(dim=0), average_logits + return quizzes, nb_correct.sum(dim=0), summed_logits.mean()