Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 9 Oct 2024 14:19:00 +0000 (16:19 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 9 Oct 2024 14:19:00 +0000 (16:19 +0200)
attae.py

index d971b9d..d0f4e6b 100755 (executable)
--- a/attae.py
+++ b/attae.py
@@ -337,6 +337,12 @@ class Reasoning(nn.Module):
         x = self.trunk_joint(x)
 
         f, x = x[:, : x_star.size(1), :], x[:, x_star.size(1) :, :]
+
+        if hasattr(self, "forced_f") and self.forced_f is not None:
+            f = self.forced_f.clone()
+
+        self.pred_f = f.clone()
+
         x = x.reshape(nb, self.nb_chunks, T // self.nb_chunks, dim)
         f = f.unsqueeze(1).expand(nb, self.nb_chunks, S, dim)
         x = x.reshape(nb * self.nb_chunks, T // self.nb_chunks, dim)