From: Francois Fleuret
Date: Wed, 27 Jul 2022 14:22:26 +0000 (+0200)
Subject: OCDC
X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=0c51561334475af559cda12627388c9d5567a55f;p=mygpt.git

OCDC
---

diff --git a/mygpt.py b/mygpt.py
index ab16e1e..43711b3 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -107,7 +107,8 @@ class MyGPT(nn.Module):
                     nn.LayerNorm(dim_model),
                     QKVAttention(
                         dim_in = dim_model,
-                        dim_qk = dim_keys, dim_v = dim_model // nb_heads,
+                        dim_qk = dim_keys,
+                        dim_v = dim_model // nb_heads,
                         nb_heads = nb_heads,
                         causal = True, attention_dropout = dropout
                     ),