X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=a6b257c78c0e81071ce02d99c495d96aa65ade58;hb=HEAD;hp=d6879dc08a29f05cac1998bc1ab16e46db07821c;hpb=52c6bd98650c846459f10e8303dd2e6c7ba2a68f;p=mygpt.git diff --git a/mygpt.py b/mygpt.py index d6879dc..a6b257c 100755 --- a/mygpt.py +++ b/mygpt.py @@ -14,7 +14,8 @@ from torch.nn import functional as F ############################## -class Residual(nn.Module): + +class WithResidual(nn.Module): def __init__(self, *f): super().__init__() self.f = f[0] if len(f) == 1 else nn.Sequential(*f) @@ -22,29 +23,31 @@ class Residual(nn.Module): def forward(self, x): return x + self.f(x) + ############################## -class PositionalEncoding(nn.Module): + +class AddPositionalEncoding(nn.Module): def __init__(self, len_max): super().__init__() self.len_max = len_max - # From Vaswani et al 2018 - # PE_{t,2i} = sin(t/(L^{2i/D})) - # PE_{t,2i+1} = cos(t/(L^{2i/D})) + # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D})) def forward(self, x): - t = torch.arange(x.size(1), dtype = x.dtype, device = x.device)[:, None] - j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :] - k = j%2 - pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k) + t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None] + j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :] + k = j % 2 + pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k) return x + pe + ############################## + class QKVAttention(nn.Module): - def __init__(self, - dim_in, dim_qk, dim_v, - nb_heads = 1, causal = False, attention_dropout = 0.0): + def __init__( + self, dim_in, dim_qk, dim_v, nb_heads=1, causal=False, attention_dropout=0.0 + ): super().__init__() def randw(*d): @@ -58,36 +61,47 @@ class QKVAttention(nn.Module): self.w_v = randw(nb_heads, dim_v, dim_in) self.w_o = randw(dim_v * nb_heads, dim_in) - def forward(self, x_q, x_kv = None): - if x_kv is None: x_kv = x_q + def forward(self, x_q, x_kv=None): + if x_kv is None: + x_kv = x_q - q = torch.einsum('ntc,hdc->nhtd', x_q, self.w_q) - k = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_k) - v = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_v) + q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q) + k = torch.einsum("ntc,hdc->nhtd", x_kv, self.w_k) + v = torch.einsum("ntc,hdc->nhtd", x_kv, self.w_v) - a = torch.einsum('nhtd,nhsd->nhts', q, k) / math.sqrt(q.size(3)) + a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3)) if self.causal: - mask = torch.arange(a.size(2), device = q.device)[None, None, :, None] \ - < torch.arange(a.size(3), device = q.device)[None, None, None, :] - a = a.masked_fill(mask, float('-inf')) + forbidden_attention = ( + torch.arange(a.size(2), device=q.device)[None, None, :, None] + < torch.arange(a.size(3), device=q.device)[None, None, None, :] + ) + a = a.masked_fill(forbidden_attention, float("-inf")) - a = a.softmax(dim = 3) + a = a.softmax(dim=3) a = F.dropout(a, self.attention_dropout, self.training) - y = torch.einsum('nhts,nhsd->nthd', a, v).flatten(2) + y = torch.einsum("nhts,nhsd->nthd", a, v).flatten(2) y = y @ self.w_o return y + ############################## + class MyGPT(nn.Module): - def __init__(self, - vocabulary_size, - dim_model, dim_keys, dim_hidden, - nb_heads, nb_blocks, - dropout = 0.0, len_max = 1e5): + def __init__( + self, + vocabulary_size, + dim_model, + dim_keys, + dim_hidden, + nb_heads, + nb_blocks, + dropout=0.0, + len_max=1e5, + ): super().__init__() @@ -96,57 +110,69 @@ class MyGPT(nn.Module): self.embedding = nn.Sequential( nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout), - PositionalEncoding(len_max), + AddPositionalEncoding(len_max), ) - trunk_blocks = [ ] + trunk_blocks = [] for _ in range(nb_blocks): trunk_blocks += [ - Residual( + WithResidual( nn.LayerNorm((dim_model,)), QKVAttention( - dim_in = dim_model, - dim_qk = dim_keys, - dim_v = dim_model // nb_heads, - nb_heads = nb_heads, - causal = True, attention_dropout = dropout + dim_in=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + causal=True, + attention_dropout=dropout, ), ), - Residual( + WithResidual( nn.LayerNorm((dim_model,)), - nn.Linear(in_features = dim_model, out_features = dim_hidden), + nn.Linear(in_features=dim_model, out_features=dim_hidden), nn.ReLU(), - nn.Linear(in_features = dim_hidden, out_features = dim_model), + nn.Linear(in_features=dim_hidden, out_features=dim_model), nn.Dropout(dropout), ), ] self.trunk = nn.Sequential(*trunk_blocks) - self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size) + self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size) + + with torch.no_grad(): + for m in self.modules(): + if isinstance(m, nn.Embedding): + m.weight.normal_(mean=0, std=2e-2) + elif isinstance(m, nn.LayerNorm): + m.bias.zero_() + m.weight.fill_(1.0) def forward(self, x): - x = F.pad(x, (1, 0)) + x = F.pad(x, (1, -1)) x = self.embedding(x) x = self.trunk(x) x = self.readout(x) - x = F.pad(x, (0, 0, 0, -1)) return x + ###################################################################### -if __name__ == '__main__': - print('Basic check.') +if __name__ == "__main__": + print("Basic check.") vocabulary_size = 10 x = torch.randint(vocabulary_size, (25, 100)) model = MyGPT( - vocabulary_size = vocabulary_size, - dim_model = 18, dim_keys = 50, dim_hidden = 100, - nb_heads = 2, nb_blocks = 3, - dropout = 0.1 + vocabulary_size=vocabulary_size, + dim_model=18, + dim_keys=50, + dim_hidden=100, + nb_heads=2, + nb_blocks=3, + dropout=0.1, ) y = model(x)