print(f"\n\nSANITY {a**T}\n")
exit(0)
+
+######################################################################
+
+2024 Jan 14 13:39:37 (from mygpt.py)
+
+ epsilon = 0.5
+
+ dropout_head = (
+ (torch.rand(N, H, 1, t1 - t0, device=G.device).sort(dim=3).indices == 0)
+ .expand_as(G)
+ .float()
+ )
+
+ dropout_tail = dropout_head.cumsum(dim=3) - dropout_head
+
+ dropout_active = (
+ torch.rand(N, 1, 1, 1, device=G.device) < self.proba_gate_dropout
+ ).long()
+
+ dropout_head *= dropout_active
+ dropout_tail *= dropout_active
+
+ G = (
+ G
+ + dropout_head * (1 - epsilon - G.detach())
+ - dropout_tail * G.detach()
+ )