attention_dropout=0.0,
len_max=1e5,
logger=print,
- args,
+ args=None,
):
super().__init__()
attention_dropout=0.0,
len_max=1e5,
logger=print,
- args,
+ args=None,
):
super().__init__()
attention_dropout=0.0,
len_max=1e5,
logger=print,
- args,
+ args=None,
):
super().__init__()
warnings.warn("gate dropout", RuntimeWarning)
+ if self.gate_dropout_sync:
+ shape_kill = (N, 1, 1)
+ else:
+ shape_kill = (N, H, R)
+
# Pick a point in each of the NxHxR timeline and set this
# entry and the following to 1
kill = (
- torch.rand(N, H, R, t1 - t0, device=G.device).sort(dim=3).indices == 0
+ torch.rand(*shape_kill, t1 - t0, device=G.device).sort(dim=3).indices
+ == 0
).cumsum(dim=3)
# Keep these mask for only some of the NxHxR
kill = kill * (
- torch.rand(N, H, R, 1, device=G.device) <= self.gate_dropout_proba
+ torch.rand(*shape_kill, 1, device=G.device) <= self.gate_dropout_proba
)
# The coefficient to keep are the complementary
causal=False,
attention_dropout=0.0,
logger=print,
- args,
+ args=None,
):
super().__init__()
len_max=1e5,
attention_layer="kvrec",
logger=print,
- args,
+ args=None,
):
super().__init__()
causal=causal,
attention_dropout=dropout,
logger=logger,
- args,
+ args=args,
)
elif attention_layer == "dumbrec":
return DumbRec(
nb_lines=nb_lines,
attention_dropout=dropout,
logger=logger,
- args,
+ args=args,
)
elif attention_layer == "kvrec":
return KVRec(
nb_lines=nb_lines,
attention_dropout=dropout,
logger=logger,
- args,
+ args=args,
)
elif attention_layer == "caterpillar":
return Caterpillar(
caterpillar_height=self.caterpillar_height,
attention_dropout=dropout,
logger=logger,
- args,
+ args=args,
)
else:
raise ValueError(f"Unknown attention type {attention_layer}.")