From 16cb07f99cf770fb4e97824f874a68cbddd4c1cf Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 23 Jul 2023 20:29:54 +0200 Subject: [PATCH] Update. --- problems.py | 1 - tasks.py | 5 +++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/problems.py b/problems.py index 78bb64e..5161587 100755 --- a/problems.py +++ b/problems.py @@ -156,4 +156,3 @@ class ProblemAddition(Problem): # for strain, stest in zip(train_seq, test_seq): # s = torch.cat((strain, stest), 0) - diff --git a/tasks.py b/tasks.py index 421aee4..b2f7d7d 100755 --- a/tasks.py +++ b/tasks.py @@ -76,6 +76,7 @@ class Task: import problems + class SandBox(Task): def __init__( self, @@ -1134,8 +1135,8 @@ class RPL(Task): ) if save_attention_image is not None: - ns=torch.randint(self.test_input.size(0),(1,)).item() - input = self.test_input[ns:ns+1].clone() + ns = torch.randint(self.test_input.size(0), (1,)).item() + input = self.test_input[ns : ns + 1].clone() last = (input != self.t_nul).max(0).values.nonzero().max() + 3 input = input[:, :last].to(self.device) -- 2.39.5