From: François Fleuret <francois@fleuret.org>
Date: Sat, 6 Jan 2024 13:49:59 +0000 (+0100)
Subject: Update.
X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=03100792df9e52b739bbe4692bed6c4f6b575242;p=mygptrnn.git

Update.
---

diff --git a/mygpt.py b/mygpt.py
index 9bacaff..c061eb4 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -545,6 +545,8 @@ class Caterpillar(nn.Module):
             torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
         ).sigmoid()
 
+        G = F.dropout(G, self.attention_dropout, self.training)
+
         V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
         K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)