From 17a885dc2c98bc5370dcc2ebd32493dcebdd4225 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 1 Jul 2024 09:40:59 +0300 Subject: [PATCH] Update. --- main.py | 6 ++-- mygpt.py | 26 +++++++++++++++- quizz_machine.py | 80 +++++++++++++++++++----------------------------- wireworld.py | 19 +++++++++--- 4 files changed, 74 insertions(+), 57 deletions(-) diff --git a/main.py b/main.py index 6e5545c..11eb8fd 100755 --- a/main.py +++ b/main.py @@ -434,9 +434,9 @@ def create_c_quizzes( for n in range(nb_correct.max() + 1): recorded[n].append(new_c_quizzes[nb_correct == n].clone()) - log_string( - f"keep c_quizzes {nb_validated()*100/nb_generated():.02f}% kept total {nb_validated()} / {nb_to_create}" - ) + nv = [recorded[n][-1].size(0) for n in recorded.keys()] + + log_string(f"keep c_quizzes kept {nv} total {nb_validated()} / {nb_to_create}") # concatenate and shuffle for n in recorded.keys(): diff --git a/mygpt.py b/mygpt.py index 7119c7a..d0fda7e 100755 --- a/mygpt.py +++ b/mygpt.py @@ -201,6 +201,26 @@ class QKVAttention(nn.Module): ############################## +class NoiseInjector(nn.Module): + def __init__(self): + super().__init__() + self.noise_std = 0.0 + + def forward(self, x): + if self.noise_std > 0: + x = x + torch.randn(x.size(), device=x.device) * self.noise_std + return x + + +def set_noise_injection(model, noise_std): + for m in model.modules(): + if isinstance(m, NoiseInjector): + m.noise_std = noise_std + + +############################## + + class MyGPT(nn.Module): def __init__( self, @@ -228,7 +248,10 @@ class MyGPT(nn.Module): for b in range(nb_blocks): trunk_blocks += [ WithResidual( - CacheWrapper(nn.LayerNorm((dim_model,))), + CacheWrapper( + nn.LayerNorm((dim_model,)), + NoiseInjector(), + ), QKVAttention( dim_in=dim_model, dim_qk=dim_keys, @@ -241,6 +264,7 @@ class MyGPT(nn.Module): WithResidual( CacheWrapper( nn.LayerNorm((dim_model,)), + NoiseInjector(), nn.Linear(in_features=dim_model, out_features=dim_hidden), nn.ReLU(), nn.Linear(in_features=dim_hidden, out_features=dim_model), diff --git a/quizz_machine.py b/quizz_machine.py index 6cad6a1..84bb558 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -12,6 +12,7 @@ import torch, torchvision from torch import nn from torch.nn import functional as F +import mygpt from mygpt import BracketedSequence ###################################################################### @@ -20,7 +21,7 @@ from mygpt import BracketedSequence class Gang(nn.Module): def __init__(self, models, nb_models_for_generation, mode="groupthink"): super().__init__() - self.models = models + self.models = nn.ModuleList(models) self.nb_models_for_generation = nb_models_for_generation self.mode = mode @@ -383,58 +384,39 @@ class QuizzMachine: ar_mask_solve = 1 - ar_mask_prompt seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device) - # bracketing of the temperature to get the target logproba if - # min_ave_seq_logproba is not None + warnings.warn("noise injection", RuntimeWarning) + temperature = 1 + noise_std = torch.rand(1).item() + self.logger(f"{noise_std=}") + mygpt.set_noise_injection(model_for_generation, noise_std) - temperature = 2 - d_temperature = 1 / 3 - - while True: - seq_logproba[...] = 0 - - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=ar_mask_prompt, - seq_logproba=seq_logproba, - temperature=temperature, - deterministic_synthesis=False, - # progress_bar_desc="sampling c_quizzes", - device=self.device, - ) + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes, + ar_mask=ar_mask_prompt, + seq_logproba=seq_logproba, + temperature=temperature, + deterministic_synthesis=False, + # progress_bar_desc="sampling c_quizzes", + device=self.device, + ) - ave_seq_logproba = seq_logproba.mean() + ave_seq_logproba = seq_logproba.mean() - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=ar_mask_solve, - seq_logproba=seq_logproba, - temperature=temperature, - deterministic_synthesis=True, - # progress_bar_desc="sampling c_quizzes", - device=self.device, - ) + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes, + ar_mask=ar_mask_solve, + seq_logproba=seq_logproba, + temperature=temperature, + deterministic_synthesis=True, + # progress_bar_desc="sampling c_quizzes", + device=self.device, + ) - # If we do not have target logprobs, get out now - if min_ave_seq_logproba is None: - break - - # Oh man that's ugly - if ave_seq_logproba < min_ave_seq_logproba: - if d_temperature > 0: - d_temperature *= -1 / 3 - temperature += d_temperature - elif ave_seq_logproba > min_ave_seq_logproba * 0.99: - if d_temperature < 0: - d_temperature *= -1 / 3 - temperature += d_temperature - else: - break - - self.logger(f"changing temperature to {temperature}") + mygpt.set_noise_injection(model_for_generation, 0.0) return c_quizzes, seq_logproba.mean() diff --git a/wireworld.py b/wireworld.py index 65b12ad..8257cad 100755 --- a/wireworld.py +++ b/wireworld.py @@ -62,9 +62,10 @@ class Wireworld(problem.Problem): def generate_frame_sequences_hard(self, nb): frame_sequences = [] + nb_frames = (self.nb_iterations - 1) * self.speed + 1 result = torch.full( - (nb * 4, self.nb_iterations * self.speed, self.height, self.width), + (nb * 4, nb_frames, self.height, self.width), self.token_empty, ) @@ -116,8 +117,8 @@ class Wireworld(problem.Problem): result[n, 0, i + vi, j + vj] = self.token_tail break - if torch.rand(1) < 0.75: - break + # if torch.rand(1) < 0.75: + break weight = torch.full((1, 1, 3, 3), 1.0) @@ -130,7 +131,10 @@ class Wireworld(problem.Problem): # tail->conductor # conductor->head if 1 or 2 head in the neighborhood, or remains conductor - for l in range(self.nb_iterations * self.speed - 1): + nb_heads = (result[:, 0] == self.token_head).flatten(1).long().sum(dim=1) + valid = nb_heads > 0 + + for l in range(nb_frames - 1): nb_head_neighbors = ( F.conv2d( input=(result[:, l] == self.token_head).float()[:, None, :, :], @@ -153,6 +157,13 @@ class Wireworld(problem.Problem): + (1 - mask_1_or_2_heads) * self.token_conductor ) ) + pred_nb_heads = nb_heads + nb_heads = ( + (result[:, l + 1] == self.token_head).flatten(1).long().sum(dim=1) + ) + valid = torch.logical_and(valid, (nb_heads >= pred_nb_heads)) + + result = result[valid] result = result[ :, torch.arange(self.nb_iterations, device=result.device) * self.speed -- 2.20.1