From: François Fleuret Date: Sun, 4 Aug 2024 08:36:25 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=53f7d335a94e4ca9b5bc83f1212525809eb5270a;p=culture.git Update. --- diff --git a/main.py b/main.py index 8f3568f..63f6cce 100755 --- a/main.py +++ b/main.py @@ -372,7 +372,8 @@ def run_tests(model, quiz_machine, local_device=main_device): for input in tqdm.tqdm(src, dynamic_ncols=True, desc="test"): input = input.to(local_device) - output = model(mygpt.BracketedSequence(input)).x + sigma = torch.rand(input.size(), device=input.device).sort(dim=1).indices + output = model(mygpt.BracketedSequence(input), sigma).x loss = F.cross_entropy(output.transpose(1, 2), input) acc_test_loss += loss.item() * input.size(0) nb_test_samples += input.size(0) @@ -417,7 +418,8 @@ def one_epoch(model, quiz_machine, local_device=main_device): targets = input - output = model(mygpt.BracketedSequence(input)).x + sigma = torch.rand(input.size(), device=input.device).sort(dim=1).indices + output = model(mygpt.BracketedSequence(input), sigma).x loss_per_token = F.cross_entropy( output.transpose(1, 2), targets, reduction="none" ) diff --git a/mygpt.py b/mygpt.py index 15ed80e..b1cdf4d 100755 --- a/mygpt.py +++ b/mygpt.py @@ -90,22 +90,38 @@ class AddPositionalEncoding(nn.Module): # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D})) - def forward(self, bs): + def forward(self, bs, sigma=None): if bs.first == 0: - t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[ - :, None - ] - j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[ - None, : - ] - k = j % 2 - self.pe = torch.sin( - t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k - ) + if sigma is None: + t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[ + None, :, None + ] + j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[ + None, None, : + ] + k = j % 2 + self.pe = torch.sin( + t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k + ) + else: + t_out = sigma[:, :, None] + t_in = F.pad(t_out, (0, 0, 1, -1), value=-1) + j = torch.arange( + bs.x.size(2) // 2, dtype=bs.x.dtype, device=bs.x.device + )[None, None, :] + k = j % 2 + pe_out = torch.sin( + t_out / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k + ) + pe_in = torch.sin( + t_in / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k + ) + self.pe = torch.cat([pe_in, pe_out], dim=2) + self.cache_y = bs.x.new(bs.x.size()) self.cache_y[:, bs.first : bs.first + bs.nb] = ( - bs.slice() + self.pe[bs.first : bs.first + bs.nb] + bs.slice() + self.pe[:, bs.first : bs.first + bs.nb] ) return BracketedSequence(self.cache_y, bs.first, bs.nb) @@ -262,11 +278,12 @@ class MyGPT(nn.Module): self.temperature = 1.0 - self.embedding = nn.Sequential( - CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)), - AddPositionalEncoding(len_max), + self.embedding = CacheWrapper( + nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout) ) + self.positional_encoding = AddPositionalEncoding(len_max) + trunk_blocks = [] for b in range(nb_blocks): @@ -331,12 +348,19 @@ class MyGPT(nn.Module): m.bias.zero_() m.weight.fill_(1.0) - def forward(self, bs): + def forward(self, bs, sigma=None): + if sigma is not None: + bs.x = bs.x.gather(dim=1, index=sigma) bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb) bs = self.embedding(bs) + bs = self.positional_encoding(bs, sigma) bs = self.trunk(bs) bs = self.readout(bs) bs.x[:, bs.first : bs.first + bs.nb] /= self.temperature + if sigma is not None: + bs.x.scatter_( + dim=1, index=sigma[:, :, None].expand_as(bs.x), src=bs.x.clone() + ) return bs def encode(self, bs): @@ -351,7 +375,6 @@ class MyGPT(nn.Module): def partial_forward(self, bs, start_layer=None, end_layer=None): if start_layer is None: - # print(f"GENERATE {bs.first} {bs.first+bs.nb}") bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb) bs = self.embedding(bs) if end_layer is not None: