Update.
diff --git a/mygpt.py b/mygpt.py
index 2d33574..760a3c6 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -202,7 +202,7 @@ class DumbRec(nn.Module):
         attention_dropout=0.0,
         len_max=1e5,
         logger=print,
-        args,
+        args=None,
     ):
         super().__init__()
@@ -333,7 +333,7 @@ class KVRec(nn.Module):
         attention_dropout=0.0,
         len_max=1e5,
         logger=print,
-        args,
+        args=None,
     ):
         super().__init__()
@@ -487,7 +487,7 @@ class Caterpillar(nn.Module):
         attention_dropout=0.0,
         len_max=1e5,
         logger=print,
-        args,
+        args=None,
     ):
         super().__init__()
@@ -629,15 +629,21 @@ class Caterpillar(nn.Module):
             warnings.warn("gate dropout", RuntimeWarning)
+            if self.gate_dropout_sync:
+                shape_kill = (N, 1, 1)
+            else:
+                shape_kill = (N, H, R)
+
             # Pick a point in each of the NxHxR timeline and set this
             # entry and the following to 1
             kill = (
-                torch.rand(N, H, R, t1 - t0, device=G.device).sort(dim=3).indices == 0
+                torch.rand(*shape_kill, t1 - t0, device=G.device).sort(dim=3).indices
+                == 0
             ).cumsum(dim=3)
             # Keep these mask for only some of the NxHxR
             kill = kill * (
-                torch.rand(N, H, R, 1, device=G.device) <= self.gate_dropout_proba
+                torch.rand(*shape_kill, 1, device=G.device) <= self.gate_dropout_proba
             )
             # The coefficient to keep are the complementary
@@ -715,7 +721,7 @@ class QKVAttention(nn.Module):
         causal=False,
         attention_dropout=0.0,
         logger=print,
-        args,
+        args=None,
     ):
         super().__init__()
@@ -808,7 +814,7 @@ class MyGPT(nn.Module):
         len_max=1e5,
         attention_layer="kvrec",
         logger=print,
-        args,
+        args=None,
     ):
         super().__init__()
@@ -846,7 +852,7 @@ class MyGPT(nn.Module):
                 causal=causal,
                 attention_dropout=dropout,
                 logger=logger,
-                args,
+                args=args,
             )
         elif attention_layer == "dumbrec":
             return DumbRec(
@@ -857,7 +863,7 @@ class MyGPT(nn.Module):
                 nb_lines=nb_lines,
                 attention_dropout=dropout,
                 logger=logger,
-                args,
+                args=args,
             )
         elif attention_layer == "kvrec":
             return KVRec(
@@ -868,7 +874,7 @@ class MyGPT(nn.Module):
                 nb_lines=nb_lines,
                 attention_dropout=dropout,
                 logger=logger,
-                args,
+                args=args,
             )
         elif attention_layer == "caterpillar":
             return Caterpillar(
@@ -880,7 +886,7 @@ class MyGPT(nn.Module):
                 caterpillar_height=self.caterpillar_height,
                 attention_dropout=dropout,
                 logger=logger,
-                args,
+                args=args,
             )
         else:
             raise ValueError(f"Unknown attention type {attention_layer}.")