+ def record_attention(self, v=True):
+ for m in self.modules():
+ if isinstance(m, QKVAttention):
+ m.record_attention = v
+
+ def retrieve_attention(self):
+ a = []
+ for m in self.modules():
+ if isinstance(m, QKVAttention):
+ a.append(m.a)
+ return a
+