projects
/
mygptrnn.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
359cf44
)
Update.
author
François Fleuret
<francois@fleuret.org>
Sun, 7 Jan 2024 15:10:41 +0000
(16:10 +0100)
committer
François Fleuret
<francois@fleuret.org>
Sun, 7 Jan 2024 15:10:41 +0000
(16:10 +0100)
mygpt.py
patch
|
blob
|
history
diff --git a/mygpt.py b/mygpt.py
index 7aa8578..fb08a3a 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -669,8 +669,8 @@ class Caterpillar(nn.Module):
n = torch.arange(N, device=X.device)[:, None, None, None]
t = torch.arange(t0, t1, device=X.device)[None, None, :, None]
n = torch.arange(N, device=X.device)[:, None, None, None]
t = torch.arange(t0, t1, device=X.device)[None, None, :, None]
- dv = torch.arange(DV)[None, None, None, :]
- dk = torch.arange(DK)[None, None, None, :]
+ dv = torch.arange(DV, device=X.device)[None, None, None, :]
+ dk = torch.arange(DK, device=X.device)[None, None, None, :]
u = (
torch.rand(N, CH, t1 - t0, 1, device=X.device).mul(t).long() // CL
u = (
torch.rand(N, CH, t1 - t0, 1, device=X.device).mul(t).long() // CL
@@ -679,13 +679,17 @@ class Caterpillar(nn.Module):
src_time = t - u - t0
src_head = torch.randint(H, (N, CH, t1 - t0, 1), device=X.device)
src_time = t - u - t0
src_head = torch.randint(H, (N, CH, t1 - t0, 1), device=X.device)
- mask_V = (torch.rand(N, CH, t1 - t0, DV) <= self.proba_flashback).long()
+ mask_V = (
+ torch.rand(N, CH, t1 - t0, DV, device=X.device) <= self.proba_flashback
+ ).long()
self.rec_V[:, :, t0:t1] = (
mask_V * V[n, src_head, src_time, dv]
+ (1 - mask_V) * self.rec_V[:, :, t0:t1]
)
self.rec_V[:, :, t0:t1] = (
mask_V * V[n, src_head, src_time, dv]
+ (1 - mask_V) * self.rec_V[:, :, t0:t1]
)
- mask_K = (torch.rand(N, CH, t1 - t0, DK) <= self.proba_flashback).long()
+ mask_K = (
+ torch.rand(N, CH, t1 - t0, DK, device=X.device) <= self.proba_flashback
+ ).long()
self.rec_K[:, :, t0:t1] = (
mask_K * K[n, src_head, src_time, dk]
+ (1 - mask_K) * self.rec_K[:, :, t0:t1]
self.rec_K[:, :, t0:t1] = (
mask_K * K[n, src_head, src_time, dk]
+ (1 - mask_K) * self.rec_K[:, :, t0:t1]