projects
/
mygptrnn.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[mygptrnn.git]
/
mygpt.py
diff --git a/mygpt.py b/mygpt.py
index b885e21..95e5527 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -10,6 +10,8 @@
# with a caching mechanism for keys and values to avoid a O(N^3) cost
# for auto-regression.
# with a caching mechanism for keys and values to avoid a O(N^3) cost
# for auto-regression.
+# This implementation is equipped with RNN layers to replace the MHA
+
import math, warnings
import torch, einops
import math, warnings
import torch, einops
@@ -772,7 +774,6 @@ class MyGPT(nn.Module):
nb_blocks,
nb_lines=None,
caterpillar_height=None,
nb_blocks,
nb_lines=None,
caterpillar_height=None,
- dim_rec_v=-1,
causal=False,
dropout=0.0,
len_max=1e5,
causal=False,
dropout=0.0,
len_max=1e5,
@@ -818,7 +819,7 @@ class MyGPT(nn.Module):
return DumbRec(
dim_model=dim_model,
dim_qk=dim_keys,
return DumbRec(
dim_model=dim_model,
dim_qk=dim_keys,
- dim_v=dim_rec_v,
+ dim_v=dim_model // nb_heads,
nb_heads=nb_heads,
nb_lines=nb_lines,
attention_dropout=dropout,
nb_heads=nb_heads,
nb_lines=nb_lines,
attention_dropout=dropout,
@@ -827,7 +828,7 @@ class MyGPT(nn.Module):
return KVRec(
dim_model=dim_model,
dim_qk=dim_keys,
return KVRec(
dim_model=dim_model,
dim_qk=dim_keys,
- dim_v=dim_rec_v,
+ dim_v=dim_model // nb_heads,
nb_heads=nb_heads,
nb_lines=nb_lines,
attention_dropout=dropout,
nb_heads=nb_heads,
nb_lines=nb_lines,
attention_dropout=dropout,
@@ -836,7 +837,7 @@ class MyGPT(nn.Module):
return Caterpillar(
dim_model=dim_model,
dim_qk=dim_keys,
return Caterpillar(
dim_model=dim_model,
dim_qk=dim_keys,
- dim_v=dim_rec_v,
+ dim_v=dim_model // nb_heads,
nb_heads=nb_heads,
caterpillar_length=self.caterpillar_length,
caterpillar_height=self.caterpillar_height,
nb_heads=nb_heads,
caterpillar_length=self.caterpillar_length,
caterpillar_height=self.caterpillar_height,