From: François Fleuret Date: Wed, 9 Oct 2024 14:19:00 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=52dfbbcf39af905cd98129a7daf500a87a9824c4;p=culture.git Update. --- diff --git a/attae.py b/attae.py index d971b9d..d0f4e6b 100755 --- a/attae.py +++ b/attae.py @@ -337,6 +337,12 @@ class Reasoning(nn.Module): x = self.trunk_joint(x) f, x = x[:, : x_star.size(1), :], x[:, x_star.size(1) :, :] + + if hasattr(self, "forced_f") and self.forced_f is not None: + f = self.forced_f.clone() + + self.pred_f = f.clone() + x = x.reshape(nb, self.nb_chunks, T // self.nb_chunks, dim) f = f.unsqueeze(1).expand(nb, self.nb_chunks, S, dim) x = x.reshape(nb * self.nb_chunks, T // self.nb_chunks, dim)