# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import math

import torch
from torch import nn
from torch.nn import functional as F

##############################


class WithResidual(nn.Module):
    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, x):
        return x + self.f(x)


##############################
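

# Illustrative sketch, not in the original: WithResidual wraps its
# argument(s) into f and returns x + f(x); with a leading LayerNorm this
# is the pre-norm residual pattern used by the blocks of MyGPT below.
# The sizes are arbitrary choices for the example.
def _demo_with_residual():
    block = WithResidual(nn.LayerNorm((8,)), nn.Linear(8, 8))
    x = torch.randn(2, 5, 8)
    y = block(x)  # x + Linear(LayerNorm(x)), same shape as x
    assert y.shape == x.shape


##############################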


class AddPositionalEncoding(nn.Module):
    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    # [Vaswani et al. 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
    def forward(self, x):
        t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None]
        j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :]
        k = j % 2  # 0 on even channels, 1 on odd ones
        # since sin(u + pi/2) = cos(u), a single sin() covers both cases
        pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k)
        return x + pe


##############################
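

# Illustrative check, not in the original: because sin(u + pi/2) = cos(u),
# the single sin() in AddPositionalEncoding yields sin on even channels
# (k = 0) and cos on odd channels (k = 1), as in Vaswani et al. The sizes
# are arbitrary choices for the example.
def _demo_positional_encoding():
    T, D = 4, 6
    pe = AddPositionalEncoding(len_max=1e5)(torch.zeros(1, T, D))
    t = torch.arange(T, dtype=torch.float)[:, None]
    j = torch.arange(0, D, 2, dtype=torch.float)[None, :]
    ref = t / 1e5 ** (j / D)
    assert torch.allclose(pe[0, :, 0::2], torch.sin(ref), atol=1e-6)
    assert torch.allclose(pe[0, :, 1::2], torch.cos(ref), atol=1e-6)


##############################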


class QKVAttention(nn.Module):
    def __init__(
        self, dim_in, dim_qk, dim_v, nb_heads=1, causal=False, attention_dropout=0.0
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.causal = causal
        self.attention_dropout = attention_dropout

        self.w_q = randw(nb_heads, dim_qk, dim_in)
        self.w_k = randw(nb_heads, dim_qk, dim_in)
        self.w_v = randw(nb_heads, dim_v, dim_in)
        self.w_o = randw(dim_v * nb_heads, dim_in)

    def forward(self, x_q, x_kv=None):
        # self-attention unless a separate key/value sequence is given
        if x_kv is None:
            x_kv = x_q

        q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q)
        k = torch.einsum("ntc,hdc->nhtd", x_kv, self.w_k)
        v = torch.einsum("ntc,hdc->nhtd", x_kv, self.w_v)

        a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))

        if self.causal:
            # mask attention from position t to any position s > t
            forbidden_attention = (
                torch.arange(a.size(2), device=q.device)[None, None, :, None]
                < torch.arange(a.size(3), device=q.device)[None, None, None, :]
            )
            a = a.masked_fill(forbidden_attention, float("-inf"))

        a = a.softmax(dim=3)
        a = F.dropout(a, self.attention_dropout, self.training)
        y = torch.einsum("nhts,nhsd->nthd", a, v).flatten(2)

        return y @ self.w_o


##############################
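

# Illustrative check, not in the original: with causal=True the output at
# position t must depend only on positions <= t, so perturbing the last
# time step leaves every earlier output unchanged. The sizes are arbitrary
# choices for the example.
def _demo_causal_mask():
    att = QKVAttention(dim_in=8, dim_qk=4, dim_v=4, nb_heads=2, causal=True)
    x1 = torch.randn(1, 5, 8)
    x2 = x1.clone()
    x2[0, -1] += 1.0  # perturb only the last position
    assert torch.allclose(att(x1)[:, :-1], att(x2)[:, :-1])


##############################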


class MyGPT(nn.Module):
    def __init__(
        self, vocabulary_size, dim_model, dim_keys, dim_hidden,
        nb_heads, nb_blocks, causal=False, dropout=0.0, len_max=1e5,
    ):
        super().__init__()

        assert dim_model % nb_heads == 0

        self.embedding = nn.Sequential(
            nn.Embedding(vocabulary_size, dim_model),
            nn.Dropout(dropout),
            AddPositionalEncoding(len_max),
        )

        trunk_blocks = []

        for _ in range(nb_blocks):
            trunk_blocks += [
                WithResidual(
                    nn.LayerNorm((dim_model,)),
                    QKVAttention(
                        dim_in=dim_model,
                        dim_qk=dim_keys,
                        dim_v=dim_model // nb_heads,
                        nb_heads=nb_heads,
                        causal=causal,
                        attention_dropout=dropout,
                    ),
                ),
                WithResidual(
                    nn.LayerNorm((dim_model,)),
                    nn.Linear(in_features=dim_model, out_features=dim_hidden),
                    nn.ReLU(),
                    nn.Linear(in_features=dim_hidden, out_features=dim_model),
                    nn.Dropout(dropout),
                ),
            ]

        self.trunk = nn.Sequential(*trunk_blocks)

        self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size)

        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, nn.Embedding):
                    m.weight.normal_(mean=0, std=2e-2)
                elif isinstance(m, nn.LayerNorm):
                    m.bias.zero_()
                    m.weight.fill_(1.0)

    def forward(self, x):
        # shift right, so the logits at position t are computed from
        # the tokens at positions < t only
        x = F.pad(x, (1, -1))
        x = self.embedding(x)
        x = self.trunk(x)
        x = self.readout(x)
        return x


######################################################################
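

# Illustrative sketch, not in the original: MyGPT.forward right-shifts its
# input with F.pad(x, (1, -1)), so the logits at position t predict token t
# from tokens 0..t-1 (assuming the model was built with causal=True and is
# in eval() mode). A minimal greedy sampler under that convention;
# _demo_greedy_generation and nb_tokens are names chosen here.
def _demo_greedy_generation(model, nb_tokens):
    x = torch.zeros(1, nb_tokens, dtype=torch.int64)
    for t in range(nb_tokens):
        logits = model(x)  # (1, nb_tokens, vocabulary_size)
        x[0, t] = logits[0, t].argmax()
    return x


######################################################################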

if __name__ == "__main__":
    print("Basic check.")

    # small arbitrary sizes for a quick shape check
    vocabulary_size = 10
    x = torch.randint(vocabulary_size, (25, 100))

    model = MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=16,
        dim_keys=50,
        dim_hidden=100,
        nb_heads=2,
        nb_blocks=3,
        causal=True,
        dropout=0.1,
    )

    y = model(x)

    print(f"{x.size()=} {y.size()=}")

######################################################################