projects
/
mygptrnn.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
de99e48
)
Update.
author
François Fleuret
<francois@fleuret.org>
Thu, 18 Jan 2024 23:19:12 +0000
(
00:19
+0100)
committer
François Fleuret
<francois@fleuret.org>
Thu, 18 Jan 2024 23:19:12 +0000
(
00:19
+0100)
mygpt.py
patch
|
blob
|
history
diff --git a/mygpt.py b/mygpt.py
index 5451584..fb24b9a 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -644,17 +644,18 @@ class Caterpillar(nn.Module):
warnings.warn("gate dropout", RuntimeWarning)
warnings.warn("gate dropout", RuntimeWarning)
- # kill = (
- # torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout
- # ).float()
-
+ # Pick a point in each of the NxHxR timeline and set this
+ # entry and the following to 1
kill = (
torch.rand(N, H, R, t1 - t0, device=G.device).sort(dim=3).indices == 0
).cumsum(dim=3)
kill = (
torch.rand(N, H, R, t1 - t0, device=G.device).sort(dim=3).indices == 0
).cumsum(dim=3)
+
+ # Keep these mask for only some of the NxHxR
kill = kill * (
torch.rand(N, H, R, 1, device=G.device) <= self.proba_gate_dropout
)
kill = kill * (
torch.rand(N, H, R, 1, device=G.device) <= self.proba_gate_dropout
)
+ # The coefficient to keep are the complementary
mask = 1 - kill
masked_next_V, masked_next_K = recurrence(G * mask, V, K)
mask = 1 - kill
masked_next_V, masked_next_K = recurrence(G * mask, V, K)
@@ -674,8 +675,8 @@ class Caterpillar(nn.Module):
Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)
Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)
- # We build tensors NxHxTxFxL where N is the sample index, H
- # the head, T the time, F the row in the caterpillar, and L
+ # We build tensors NxHxTxRxL where N is the sample index, H
+ # the head, T the time, R the row in the caterpillar, and L
# the column in the caterpillar
windowed_V = moving_window(
# the column in the caterpillar
windowed_V = moving_window(
@@ -689,7 +690,7 @@ class Caterpillar(nn.Module):
# We have an attention score for each of the RxL values
ar = torch.einsum(
# We have an attention score for each of the RxL values
ar = torch.einsum(
- "nhtd,nftld->nhtfl",
+ "nhtd,nrtld->nhtrl",
Q,
windowed_K,
) / math.sqrt(DK)
Q,
windowed_K,
) / math.sqrt(DK)