From: François Fleuret Date: Sun, 23 Jul 2023 09:53:52 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=291c38d093894d46fba6eb45f82e5b65a2a1cb8b;p=culture.git Update. --- diff --git a/graph.py b/graph.py index bd80187..2c7caf8 100755 --- a/graph.py +++ b/graph.py @@ -59,7 +59,7 @@ def save_attention_image( ctx.set_line_width(0.25) for d in range(len(attention_matrices)): - at = attention_matrices[d] + at = attention_matrices[d].to("cpu") ni = torch.arange(at.size(0))[:, None].expand_as(at) nj = torch.arange(at.size(1))[None, :].expand_as(at) at = at.flatten() diff --git a/tasks.py b/tasks.py index 234e780..42d9126 100755 --- a/tasks.py +++ b/tasks.py @@ -1285,7 +1285,7 @@ class RPL(Task): if save_attention_image is not None: input = self.test_input[:1].clone() last = (input != self.t_nul).max(0).values.nonzero().max() + 3 - input = input[:, :last] + input = input[:, :last].to(self.device) with torch.autograd.no_grad(): t = model.training