X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=39372f314787479be04888d859218f94334f3940;hb=35a16ac34a3f1af05323a9cb3823fbcfd74035a4;hp=2a1833d831fe2fdfa0c097f7f4b4e8a9327ef88b;hpb=694923fcfa606cf8fc9ee6066ef4bbdea27003ce;p=culture.git diff --git a/tasks.py b/tasks.py index 2a1833d..39372f3 100755 --- a/tasks.py +++ b/tasks.py @@ -22,6 +22,7 @@ def masked_inplace_autoregression( batch_size, input, ar_mask, + summed_logits, temperature, deterministic_synthesis, forbidden_tokens=None, @@ -41,16 +42,15 @@ def masked_inplace_autoregression( total=(input.size(0) + batch_size - 1) // batch_size, ) - sum_logits = 0 - with torch.autograd.no_grad(): t = model.training model.eval() for input, ar_mask in batches: - sum_logits += model.masked_inplace_autoregression( + model.masked_inplace_autoregression( input=input, ar_mask=ar_mask, + summed_logits=summed_logits, temperature=temperature, deterministic_synthesis=deterministic_synthesis, forbidden_tokens=forbidden_tokens, @@ -59,8 +59,6 @@ def masked_inplace_autoregression( model.train(t) - return sum_logits - ###################################################################### @@ -180,6 +178,7 @@ class World(Task): batch_size=self.batch_size, input=result, ar_mask=ar_mask, + summed_logits=None, temperature=1.0, deterministic_synthesis=deterministic_synthesis, progress_bar_desc=None, @@ -219,6 +218,7 @@ class World(Task): batch_size=self.batch_size, input=result, ar_mask=ar_mask, + summed_logits=None, temperature=1.0, deterministic_synthesis=deterministic_synthesis, progress_bar_desc=None, @@ -266,23 +266,27 @@ class World(Task): ) ar_mask = torch.full(quizzes.size(), 1, device=self.device) + summed_logits = torch.empty(nb, device=self.device) temperature = 1 d_temperature = 1 while True: - sum_logits = masked_inplace_autoregression( + summed_logits[...] = 0 + + masked_inplace_autoregression( model=model, batch_size=self.batch_size, input=quizzes, ar_mask=ar_mask, + summed_logits=summed_logits, temperature=temperature, deterministic_synthesis=False, progress_bar_desc="creating quizzes", device=self.device, ) - average_logits = sum_logits / quizzes.size(0) + average_logits = summed_logits.mean() logger(f"{average_logits=} {desired_average_logits=}") @@ -290,14 +294,16 @@ class World(Task): break # Oh man that's ugly - if average_logits > desired_average_logits: + if average_logits < desired_average_logits: if d_temperature < 0: d_temperature *= -0.5 temperature += d_temperature - else: + elif average_logits > desired_average_logits * 0.95: if d_temperature > 0: d_temperature *= -0.5 temperature += d_temperature + else: + break logger(f"chaging temperature to {temperature}") @@ -329,6 +335,7 @@ class World(Task): batch_size=self.batch_size, input=result, ar_mask=ar_mask, + summed_logits=None, temperature=1.0, deterministic_synthesis=True, progress_bar_desc="solving quizzes", @@ -344,6 +351,7 @@ class World(Task): batch_size=self.batch_size, input=reverse_result, ar_mask=ar_mask, + summed_logits=None, temperature=1.0, deterministic_synthesis=True, progress_bar_desc="solving reversed quizzes", @@ -363,4 +371,4 @@ class World(Task): # for k in nb_correct: # f.write(f"{k}\n") - return quizzes, nb_correct.sum(dim=0), sum_logits + return quizzes, nb_correct.sum(dim=0), summed_logits.mean()