From 1fa5b661d7005daa019a89755af18f698e1cc231 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 7 Jan 2024 16:10:41 +0100 Subject: [PATCH] Update. --- mygpt.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/mygpt.py b/mygpt.py index 7aa8578..fb08a3a 100755 --- a/mygpt.py +++ b/mygpt.py @@ -669,8 +669,8 @@ class Caterpillar(nn.Module): n = torch.arange(N, device=X.device)[:, None, None, None] t = torch.arange(t0, t1, device=X.device)[None, None, :, None] - dv = torch.arange(DV)[None, None, None, :] - dk = torch.arange(DK)[None, None, None, :] + dv = torch.arange(DV, device=X.device)[None, None, None, :] + dk = torch.arange(DK, device=X.device)[None, None, None, :] u = ( torch.rand(N, CH, t1 - t0, 1, device=X.device).mul(t).long() // CL @@ -679,13 +679,17 @@ class Caterpillar(nn.Module): src_time = t - u - t0 src_head = torch.randint(H, (N, CH, t1 - t0, 1), device=X.device) - mask_V = (torch.rand(N, CH, t1 - t0, DV) <= self.proba_flashback).long() + mask_V = ( + torch.rand(N, CH, t1 - t0, DV, device=X.device) <= self.proba_flashback + ).long() self.rec_V[:, :, t0:t1] = ( mask_V * V[n, src_head, src_time, dv] + (1 - mask_V) * self.rec_V[:, :, t0:t1] ) - mask_K = (torch.rand(N, CH, t1 - t0, DK) <= self.proba_flashback).long() + mask_K = ( + torch.rand(N, CH, t1 - t0, DK, device=X.device) <= self.proba_flashback + ).long() self.rec_K[:, :, t0:t1] = ( mask_K * K[n, src_head, src_time, dk] + (1 - mask_K) * self.rec_K[:, :, t0:t1] -- 2.20.1