From 52dfbbcf39af905cd98129a7daf500a87a9824c4 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 9 Oct 2024 16:19:00 +0200 Subject: [PATCH] Update. --- attae.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/attae.py b/attae.py index d971b9d..d0f4e6b 100755 --- a/attae.py +++ b/attae.py @@ -337,6 +337,12 @@ class Reasoning(nn.Module): x = self.trunk_joint(x) f, x = x[:, : x_star.size(1), :], x[:, x_star.size(1) :, :] + + if hasattr(self, "forced_f") and self.forced_f is not None: + f = self.forced_f.clone() + + self.pred_f = f.clone() + x = x.reshape(nb, self.nb_chunks, T // self.nb_chunks, dim) f = f.unsqueeze(1).expand(nb, self.nb_chunks, S, dim) x = x.reshape(nb * self.nb_chunks, T // self.nb_chunks, dim) -- 2.39.5