From c3016d368197199fe0a75f644c593db2d3081da2 Mon Sep 17 00:00:00 2001
From: François Fleuret
Date: Fri, 23 Aug 2024 14:57:41 +0200
Subject: [PATCH] Update.

---
 main.py | 195 ++++++++++++++++++++++++++++++++++++++++----------------
 1 file changed, 139 insertions(+), 56 deletions(-)

diff --git a/main.py b/main.py
index 289bae4..c6d76ee 100755
--- a/main.py
+++ b/main.py
@@ -750,6 +750,18 @@ from mygpt import (
 )
 
 
+class MultiEmbedding(nn.Module):
+    def __init__(self, nb_values, dim):
+        super().__init__()
+        self.embeddings = nn.ModuleList([nn.Embedding(n, dim) for n in nb_values])
+
+    def forward(self, x):
+        y = 0
+        for f, z in zip(self.embeddings, x.split(1, dim=2)):
+            y = y + f(z[:, :, 0])
+        return y
+
+
 class MyAttentionAE(nn.Module):
     def __init__(
         self,
@@ -766,11 +778,14 @@ class MyAttentionAE(nn.Module):
 
         assert dim_model % nb_heads == 0
 
-        self.embedding = nn.Sequential(
-            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
+        self.embedding = CacheWrapper(
+            nn.Sequential(
+                MultiEmbedding((vocabulary_size, 2), dim_model), nn.Dropout(dropout)
+            ),
         )
 
-        self.positional_encoding = TrainablePositionalEncoding(dim_model, len_max)
+        # self.positional_encoding = TrainablePositionalEncoding(dim_model, len_max)
+        self.positional_encoding = VaswaniPositionalEncoding(len_max=1e5)
 
         trunk_blocks = []
 
@@ -859,12 +874,14 @@ def test_ae(local_device=main_device):
         dropout=args.dropout,
     ).to(main_device)
 
+    pure_noise = True
+
     data_structures = [
-        (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
-        (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (1, 1, 1, 1)),
-        (("A", "f_A", "B", "f_B"), (0, 1, 0, 0), (1, 0, 0, 0), (1, 1, 1, 1)),
-        (("A", "f_A", "B", "f_B"), (1, 0, 0, 0), (0, 1, 0, 0), (1, 1, 1, 1)),
-        (("A", "f_A", "B", "f_B"), (1, 1, 1, 0), (0, 0, 0, 0), (1, 1, 1, 1)),
+        (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
+        (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0)),
+        (("A", "f_A", "B", "f_B"), (0, 1, 0, 0), (1, 0, 0, 0), (0, 1, 0, 0)),
+        (("A", "f_A", "B", "f_B"), (1, 0, 0, 0), (0, 1, 0, 0), (1, 0, 0, 0)),
+        (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
     ]
 
     model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
@@ -894,30 +911,37 @@ def test_ae(local_device=main_device):
 
             targets = input
 
-            # mask_diffusion_noise = (mask_generate == 1) & (
-            # torch.rand(mask_generate.size(), device=mask_generate.device)
-            # <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device)
-            # )
+            input = (1 - mask_generate) * input  # PARANOIAAAAAAAAAA
 
-            # mask_diffusion_noise = mask_diffusion_noise.long()
+            if pure_noise:
+                mask_diffusion_noise = torch.rand(
+                    mask_generate.size(), device=mask_generate.device
+                ) <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device)
 
-            # input = (
-            # 1 - mask_diffusion_noise
-            # ) * input + mask_diffusion_noise * torch.randint(
-            # quiz_machine.problem.nb_colors, input.size(), device=input.device
-            # )
+                mask_diffusion_noise = mask_diffusion_noise.long()
 
-            # ------------------------------
-            input = (1 - mask_generate) * input  # PARANOIAAAAAAAAAA
-            model.eval()
-            for it in range(torch.randint(5, (1,)).item()):
-                logits = model(mygpt.BracketedSequence(input)).x
-                dist = torch.distributions.categorical.Categorical(logits=logits)
-                input = (1 - mask_generate) * input + mask_generate * dist.sample()
-            model.train()
-            # -----------------------------
+                input = input + mask_generate * mask_diffusion_noise * torch.randint(
+                    quiz_machine.problem.nb_colors, input.size(), device=input.device
+                )
+            else:
+                model.eval()
+                for it in range(torch.randint(5, (1,)).item()):
+                    logits = model(
+                        mygpt.BracketedSequence(
+                            torch.cat(
+                                [input[:, :, None], mask_generate[:, :, None]], dim=2
+                            )
+                        )
+                    ).x
+                    dist = torch.distributions.categorical.Categorical(logits=logits)
+                    input = (1 - mask_generate) * input + mask_generate * dist.sample()
+                model.train()
 
-            output = model(mygpt.BracketedSequence(input)).x
+            output = model(
+                mygpt.BracketedSequence(
+                    torch.cat([input[:, :, None], mask_generate[:, :, None]], dim=2)
+                )
+            ).x
             loss = F.cross_entropy(output.transpose(1, 2), targets)
             acc_train_loss += loss.item() * input.size(0)
             nb_train_samples += input.size(0)
@@ -947,31 +971,48 @@ def test_ae(local_device=main_device):
             ):
                 targets = input
 
-                # mask_diffusion_noise = (mask_generate == 1) & (
-                # torch.rand(mask_generate.size(), device=mask_generate.device)
-                # <= torch.rand(
-                # (mask_generate.size(0), 1), device=mask_generate.device
-                # )
-                # )
-
-                # mask_diffusion_noise = mask_diffusion_noise.long()
-
-                # input = (
-                # 1 - mask_diffusion_noise
-                # ) * input + mask_diffusion_noise * torch.randint(
-                # quiz_machine.problem.nb_colors, input.size(), device=input.device
-                # )
-
-                # ------------------------------
                 input = (1 - mask_generate) * input  # PARANOIAAAAAAAAAA
-                for it in range(torch.randint(5, (1,)).item()):
-                    logits = model(mygpt.BracketedSequence(input)).x
-                    dist = torch.distributions.categorical.Categorical(logits=logits)
-                    input = (1 - mask_generate) * input + mask_generate * dist.sample()
-                # -----------------------------
-
-                output = model(mygpt.BracketedSequence(input)).x
+                if pure_noise:
+                    mask_diffusion_noise = torch.rand(
+                        mask_generate.size(), device=mask_generate.device
+                    ) <= torch.rand(
+                        (mask_generate.size(0), 1), device=mask_generate.device
+                    )
+
+                    mask_diffusion_noise = mask_diffusion_noise.long()
+
+                    input = (
+                        input
+                        + mask_generate
+                        * mask_diffusion_noise
+                        * torch.randint(
+                            quiz_machine.problem.nb_colors,
+                            input.size(),
+                            device=input.device,
+                        )
+                    )
+                else:
+                    for it in range(torch.randint(5, (1,)).item()):
+                        logits = model(
+                            mygpt.BracketedSequence(
+                                torch.cat(
+                                    [input[:, None], mask_generate[:, None]], dim=1
+                                )
+                            )
+                        ).x
+                        dist = torch.distributions.categorical.Categorical(
+                            logits=logits
+                        )
+                        input = (
+                            1 - mask_generate
+                        ) * input + mask_generate * dist.sample()
+
+                output = model(
+                    mygpt.BracketedSequence(
+                        torch.cat([input[:, :, None], mask_generate[:, :, None]], dim=2)
+                    )
+                ).x
                 loss = F.cross_entropy(output.transpose(1, 2), targets)
                 acc_test_loss += loss.item() * input.size(0)
                 nb_test_samples += input.size(0)
@@ -992,11 +1033,43 @@ def test_ae(local_device=main_device):
 
                 input = (1 - mask_generate) * input  # PARANOIAAAAAAAAAA
 
-                result = (1 - mask_generate) * input
+                if pure_noise:
+                    mask_diffusion_noise = torch.rand(
+                        mask_generate.size(), device=mask_generate.device
+                    ) <= torch.rand(
+                        (mask_generate.size(0), 1), device=mask_generate.device
+                    )
+
+                    mask_diffusion_noise = mask_diffusion_noise.long()
+
+                    input = (
+                        input
+                        + mask_generate
+                        * mask_diffusion_noise
+                        * torch.randint(
+                            quiz_machine.problem.nb_colors,
+                            input.size(),
+                            device=input.device,
+                        )
+                    )
+                else:
+                    for it in range(torch.randint(5, (1,)).item()):
+                        logits = model(
+                            mygpt.BracketedSequence(
+                                torch.cat(
+                                    [input[:, :, None], mask_generate[:, :, None]],
+                                    dim=2,
+                                )
+                            )
+                        ).x
+                        dist = torch.distributions.categorical.Categorical(
+                            logits=logits
+                        )
+                        input = (
+                            1 - mask_generate
+                        ) * input + mask_generate * dist.sample()
 
-                # + mask_generate * torch.randint(
-                # quiz_machine.problem.nb_colors, input.size(), device=input.device
-                # )
+                result = input
                 not_converged = torch.full(
                     (result.size(0),), True, device=result.device
                 )
@@ -1004,7 +1077,17 @@ def test_ae(local_device=main_device):
                 for it in range(100):
                     pred_result = result.clone()
 
-                    logits = model(mygpt.BracketedSequence(result[not_converged])).x
+                    logits = model(
+                        mygpt.BracketedSequence(
+                            torch.cat(
+                                [
+                                    result[not_converged, :, None],
+                                    mask_generate[:, :, None],
+                                ],
+                                dim=2,
+                            )
+                        )
+                    ).x
                     dist = torch.distributions.categorical.Categorical(logits=logits)
                     update = (1 - mask_generate[not_converged]) * input[
                         not_converged
-- 
2.39.5
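
Note: the MultiEmbedding class added at the top of the patch keeps one
embedding table per input channel and sums the results, which is how the
(token, generation-mask) pairs built with
torch.cat([input[:, :, None], mask_generate[:, :, None]], dim=2) enter the
model. A minimal self-contained sketch of the idea, with hypothetical sizes
(64 token values, a binary mask channel, model width 128):

    import torch
    from torch import nn

    class MultiEmbedding(nn.Module):
        def __init__(self, nb_values, dim):
            super().__init__()
            # One table per channel, all mapping into the same dimension.
            self.embeddings = nn.ModuleList(
                [nn.Embedding(n, dim) for n in nb_values]
            )

        def forward(self, x):
            # x is (N, T, C) with C == len(nb_values); summing makes the
            # mask channel an additive offset on the token embedding.
            y = 0
            for f, z in zip(self.embeddings, x.split(1, dim=2)):
                y = y + f(z[:, :, 0])
            return y

    emb = MultiEmbedding((64, 2), 128)
    tokens = torch.randint(64, (3, 10))  # (N, T)
    mask = torch.randint(2, (3, 10))  # (N, T), values in {0, 1}
    x = torch.cat([tokens[:, :, None], mask[:, :, None]], dim=2)  # (N, T, 2)
    print(emb(x).shape)  # torch.Size([3, 10, 128])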
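Note: the pure_noise branch draws one noise level per sample and then a
per-position coin flip at that level, so a single batch covers the whole
corruption range; the positions selected by mask_generate are overwritten
with uniformly random tokens. A sketch of that corruption step, assuming
the masked tokens were already zeroed (the "PARANOIA" line), with
hypothetical shapes and a hypothetical nb_colors:

    import torch

    def corrupt(input, mask_generate, nb_colors):
        # Per-sample noise level in [0, 1), then a per-position draw.
        mask_diffusion_noise = (
            torch.rand(mask_generate.size(), device=mask_generate.device)
            <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device)
        ).long()

        input = (1 - mask_generate) * input  # keep only conditioning tokens
        # Overwrite a random subset of the masked positions with noise.
        return input + mask_generate * mask_diffusion_noise * torch.randint(
            nb_colors, input.size(), device=input.device
        )

    x = torch.randint(5, (4, 12))  # batch of 4 sequences of length 12
    m = (torch.rand(4, 12) < 0.5).long()  # positions to generate
    print(corrupt(x, m, nb_colors=5))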
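Note: the generation loop at the end of the patch resamples the masked
positions until the sequences stop changing, restricting each step to the
still-unconverged samples. The sketch below mirrors that control flow with
a toy stand-in for the model and a simple convergence test; unlike the
patch, it subsets mask_generate along with result (the patch passes the
full mask_generate[:, :, None] to torch.cat) and starts from the
conditioning tokens rather than from the noised input.

    import torch

    def toy_model(tokens, mask):
        # Stand-in denoiser: uniform logits over 5 token values.
        return torch.zeros(tokens.size(0), tokens.size(1), 5)

    def refine(model, input, mask_generate, nb_steps=100):
        result = (1 - mask_generate) * input
        not_converged = torch.full(
            (result.size(0),), True, device=result.device
        )

        for it in range(nb_steps):
            pred_result = result.clone()
            logits = model(result[not_converged], mask_generate[not_converged])
            dist = torch.distributions.categorical.Categorical(logits=logits)
            # Keep the conditioning tokens, resample the masked ones.
            result[not_converged] = (1 - mask_generate[not_converged]) * input[
                not_converged
            ] + mask_generate[not_converged] * dist.sample()
            not_converged = (pred_result != result).any(dim=1)
            if not not_converged.any():
                break

        return result

    x = torch.randint(5, (4, 12))
    m = (torch.rand(4, 12) < 0.5).long()
    print(refine(toy_model, x, m))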