From 540927e2268ae78ef6ee41259a685ee008bb9c68 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 10 Aug 2024 10:51:12 +0200 Subject: [PATCH] Update. --- mygpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mygpt.py b/mygpt.py index 7827757..041d28c 100755 --- a/mygpt.py +++ b/mygpt.py @@ -26,7 +26,7 @@ class BSQ(nn.Module): self.L = L def forward(self, input, indexes=False): - norm = input.pow(2).sum(dim=1, keepdim=True).sqrt() + norm = input.pow(2).sum(dim=2, keepdim=True).sqrt() u = input / norm if indexes: -- 2.20.1