Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 10 Aug 2024 08:51:12 +0000 (10:51 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 10 Aug 2024 08:51:12 +0000 (10:51 +0200)
mygpt.py

index 7827757..041d28c 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -26,7 +26,7 @@ class BSQ(nn.Module):
         self.L = L
 
     def forward(self, input, indexes=False):
-        norm = input.pow(2).sum(dim=1, keepdim=True).sqrt()
+        norm = input.pow(2).sum(dim=2, keepdim=True).sqrt()
         u = input / norm
 
         if indexes: