Added default configurations and reformatted with black.
[mygpt.git] / mygpt.py
#!/usr/bin/env python

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import math

import torch

from torch import nn
from torch.nn import functional as F

##############################


class WithResidual(nn.Module):
    # Wrap one or more modules and add a residual connection around them,
    # i.e. forward(x) = x + f(x).
    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, x):
        return x + self.f(x)


##############################


class AddPositionalEncoding(nn.Module):
    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    # [Vaswani et al 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
    def forward(self, x):
        t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None]
        j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :]
        k = j % 2
        # Odd dimensions get a pi/2 phase shift, which turns the sin into a cos.
        pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k)
        return x + pe

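# For example, with x.size(2) = 4 and len_max = L, position t receives the
# encoding [sin(t), cos(t), sin(t / sqrt(L)), cos(t / sqrt(L))], i.e. the
# interleaved sine / cosine pattern of the original paper.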
##############################


class QKVAttention(nn.Module):
    def __init__(
        self, dim_in, dim_qk, dim_v, nb_heads=1, causal=False, attention_dropout=0.0
    ):
        super().__init__()

        # Random weights scaled by 1/sqrt of the last (fan-in) dimension
        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.causal = causal
        self.attention_dropout = attention_dropout

        self.w_q = randw(nb_heads, dim_qk, dim_in)
        self.w_k = randw(nb_heads, dim_qk, dim_in)
        self.w_v = randw(nb_heads, dim_v, dim_in)
        self.w_o = randw(dim_v * nb_heads, dim_in)

    def forward(self, x_q, x_kv=None):
        # Self-attention unless a separate key/value sequence is provided
        if x_kv is None:
            x_kv = x_q

        q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q)
        k = torch.einsum("ntc,hdc->nhtd", x_kv, self.w_k)
        v = torch.einsum("ntc,hdc->nhtd", x_kv, self.w_v)

        # Scaled dot-product attention scores, one (T, S) matrix per head
        a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))

        if self.causal:
            # Mask out attention from position t to any position s > t
            forbidden_attention = (
                torch.arange(a.size(2), device=q.device)[None, None, :, None]
                < torch.arange(a.size(3), device=q.device)[None, None, None, :]
            )
            a = a.masked_fill(forbidden_attention, float("-inf"))

        a = a.softmax(dim=3)
        a = F.dropout(a, self.attention_dropout, self.training)
        y = torch.einsum("nhts,nhsd->nthd", a, v).flatten(2)

        y = y @ self.w_o

        return y

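# Shape summary: x_q is (N, T, dim_in); per head, queries and keys have size
# (N, nb_heads, T, dim_qk) and values (N, nb_heads, T, dim_v); the tensor
# returned to the caller is again (N, T, dim_in).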

##############################


class MyGPT(nn.Module):
    # A minimal GPT-style decoder: token embedding + positional encoding,
    # followed by nb_blocks pre-LayerNorm residual blocks (causal multi-head
    # attention, then a one-hidden-layer MLP), and a linear readout to logits.
    def __init__(
        self,
        vocabulary_size,
        dim_model,
        dim_keys,
        dim_hidden,
        nb_heads,
        nb_blocks,
        dropout=0.0,
        len_max=1e5,
    ):

        super().__init__()

        assert dim_model % nb_heads == 0

        self.embedding = nn.Sequential(
            nn.Embedding(vocabulary_size, dim_model),
            nn.Dropout(dropout),
            AddPositionalEncoding(len_max),
        )

        trunk_blocks = []

        for _ in range(nb_blocks):
            trunk_blocks += [
                WithResidual(
                    nn.LayerNorm((dim_model,)),
                    QKVAttention(
                        dim_in=dim_model,
                        dim_qk=dim_keys,
                        dim_v=dim_model // nb_heads,
                        nb_heads=nb_heads,
                        causal=True,
                        attention_dropout=dropout,
                    ),
                ),
                WithResidual(
                    nn.LayerNorm((dim_model,)),
                    nn.Linear(in_features=dim_model, out_features=dim_hidden),
                    nn.ReLU(),
                    nn.Linear(in_features=dim_hidden, out_features=dim_model),
                    nn.Dropout(dropout),
                ),
            ]

        self.trunk = nn.Sequential(*trunk_blocks)

        self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size)

        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, nn.Embedding):
                    m.weight.normal_(mean=0, std=2e-2)
                elif isinstance(m, nn.LayerNorm):
                    m.bias.zero_()
                    m.weight.fill_(1.0)

    def forward(self, x):
        # Shift the input one position to the right (dropping the last token
        # and prepending a 0), so the logits at position t are computed from
        # the tokens at positions < t only.
        x = F.pad(x, (1, -1))
        x = self.embedding(x)
        x = self.trunk(x)
        x = self.readout(x)
        return x


######################################################################

if __name__ == "__main__":
    print("Basic check.")

    vocabulary_size = 10
    x = torch.randint(vocabulary_size, (25, 100))

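    # A small extra sanity check, not in the original file: with a single
    # head, no dropout, and eval mode, QKVAttention is expected to agree with
    # PyTorch's scaled_dot_product_attention (available in torch >= 2.0)
    # applied to the same projections. Treat this as an assumption-laden
    # sketch rather than part of the reference implementation.
    if hasattr(F, "scaled_dot_product_attention"):
        att = QKVAttention(dim_in=8, dim_qk=4, dim_v=8, nb_heads=1, causal=True)
        att.eval()
        x_test = torch.randn(2, 5, 8)
        with torch.no_grad():
            q = torch.einsum("ntc,hdc->nhtd", x_test, att.w_q)
            k = torch.einsum("ntc,hdc->nhtd", x_test, att.w_k)
            v = torch.einsum("ntc,hdc->nhtd", x_test, att.w_v)
            ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
            ref = ref.transpose(1, 2).flatten(2) @ att.w_o
            assert torch.allclose(att(x_test), ref, atol=1e-5)
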
    model = MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=18,
        dim_keys=50,
        dim_hidden=100,
        nb_heads=2,
        nb_blocks=3,
        dropout=0.1,
    )

    y = model(x)
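
    # The readout produces one logit vector per position, so y has shape
    # (25, 100, vocabulary_size).
    print(f"output shape: {tuple(y.size())}")

    # Minimal sketches of how the model could be used, assuming the usual
    # next-token-prediction setup; they are illustrative additions, not part
    # of the original check.

    # Training loss: since forward() shifts its input right by one position,
    # the logits at position t predict x[:, t], so next-token prediction is a
    # plain cross-entropy between the logits and the unshifted input itself.
    loss = F.cross_entropy(y.transpose(1, 2), x)
    print(f"cross-entropy on random tokens: {loss.item():.2f}")

    # Greedy autoregressive generation: thanks to the causal mask and the
    # input shift, the logits at position s depend only on the tokens at
    # positions < s, so the sequence can be filled in left to right,
    # re-running the full forward pass at each step (inefficient but simple).
    model.eval()
    with torch.no_grad():
        generated = torch.zeros(1, 20, dtype=torch.long)
        for s in range(generated.size(1)):
            logits = model(generated)
            generated[:, s] = logits[:, s].argmax(dim=-1)
    print(f"greedy sample: {generated[0].tolist()}")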

######################################################################