self.w_q = randw(nb_heads, dim_qk, dim_in)
self.w_k = randw(nb_heads, dim_qk, dim_in)
self.w_v = randw(nb_heads, dim_v, dim_in)
- self.w_o = randw(dim_in, dim_v * nb_heads)
+ self.w_o = randw(dim_v * nb_heads, dim_in)
def forward(self, x_q, x_kv = None):
if x_kv is None: x_kv = x_q
self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size)
def forward(self, x):
- x = torch.cat((x.new_zeros(x.size(0), 1), x), 1)
+ x = F.pad(x, (1, 0))
x = self.embedding(x)
x = self.trunk(x)
x = self.readout(x)
model = MyGPT(
vocabulary_size = vocabulary_size,
- dim_model = 16, dim_keys = 50, dim_hidden = 100,
+ dim_model = 18, dim_keys = 50, dim_hidden = 100,
nb_heads = 2, nb_blocks = 3,
dropout = 0.1
)