projects
/
culture.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
| inline |
side by side
(parent:
61f3811
)
Update.
author
François Fleuret
<francois@fleuret.org>
Wed, 9 Oct 2024 14:19:00 +0000
(16:19 +0200)
committer
François Fleuret
<francois@fleuret.org>
Wed, 9 Oct 2024 14:19:00 +0000
(16:19 +0200)
attae.py
patch
|
blob
|
history
diff --git
a/attae.py
b/attae.py
index
d971b9d
..
d0f4e6b
100755
(executable)
--- a/
attae.py
+++ b/
attae.py
@@
-337,6
+337,12
@@
class Reasoning(nn.Module):
x = self.trunk_joint(x)
f, x = x[:, : x_star.size(1), :], x[:, x_star.size(1) :, :]
+
+ if hasattr(self, "forced_f") and self.forced_f is not None:
+ f = self.forced_f.clone()
+
+ self.pred_f = f.clone()
+
x = x.reshape(nb, self.nb_chunks, T // self.nb_chunks, dim)
f = f.unsqueeze(1).expand(nb, self.nb_chunks, S, dim)
x = x.reshape(nb * self.nb_chunks, T // self.nb_chunks, dim)