mygpt.git / commitdiff (parent: 6260a35)
OCDC

author     Francois Fleuret <francois@fleuret.org>
           Thu, 28 Jul 2022 19:53:21 +0000 (21:53 +0200)
committer  Francois Fleuret <francois@fleuret.org>
           Thu, 28 Jul 2022 19:53:21 +0000 (21:53 +0200)
mygpt.py
diff --git a/mygpt.py b/mygpt.py
index 7c4e06d..212e1a5 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -37,16 +37,14 @@ class PositionalEncoding(nn.Module):
         j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :]
         k = j%2
         pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)
-        return x + pe # Let broadcasting to its job
+        return x + pe

 ##############################

 class QKVAttention(nn.Module):
-    def __init__(
-        self,
-        dim_in, dim_qk, dim_v,
-        nb_heads = 1, causal = False, attention_dropout = 0.0
-    ):
+    def __init__(self,
+                 dim_in, dim_qk, dim_v,
+                 nb_heads = 1, causal = False, attention_dropout = 0.0):
         super().__init__()

         def randw(*d):
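The first hunk drops the trailing comment but keeps the behaviour it described: pe has shape (T, D), so "return x + pe" broadcasts over the batch dimension of a (N, T, D) input. The pi/2 * k phase shift also makes this one-liner equivalent to the classic interleaved sine/cosine encoding, since sin(z + pi/2) = cos(z). Below is a minimal standalone sketch checking both points; T, D, and the position index t are illustrative assumptions (in the class, both sizes are read off x, and t is defined just above this hunk):

import math
import torch

# Illustrative sizes, not from the commit; float64 so the trig
# identity check below passes comfortably.
T, D, len_max = 10, 8, 1e5

t = torch.arange(T, dtype = torch.float64)[:, None]   # positions, (T, 1)
j = torch.arange(D, dtype = torch.float64)[None, :]   # dimensions, (1, D)
k = j % 2                                             # 0 on even dims, 1 on odd
pe = torch.sin(t / (len_max ** ((j - k) / D)) + math.pi / 2 * k)  # (T, D)

# Even dims are sines, odd dims the matching cosines, since
# sin(z + pi/2) = cos(z).
assert torch.allclose(pe[:, 0::2], torch.sin(t / (len_max ** (j[:, 0::2] / D))))
assert torch.allclose(pe[:, 1::2], torch.cos(t / (len_max ** ((j[:, 1::2] - 1) / D))))

# "return x + pe" relies on broadcasting the (T, D) table over the batch.
x = torch.randn(5, T, D, dtype = torch.float64)
assert (x + pe).shape == (5, T, D)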
@@ -88,7 +86,8 @@ class MyGPT(nn.Module):
     def __init__(self,
                  vocabulary_size,
                  dim_model, dim_keys, dim_hidden,
-                 nb_heads, nb_blocks, dropout = 0.):
+                 nb_heads, nb_blocks,
+                 dropout = 0.0, len_max = 1e5):

         super().__init__()
@@ -97,7 +96,7 @@ class MyGPT(nn.Module):
         self.embedding = nn.Sequential(
             nn.Embedding(vocabulary_size, dim_model),
             nn.Dropout(dropout),
-            PositionalEncoding(len_max = 1e5),
+            PositionalEncoding(len_max),
         )

         trunk_blocks = [ ]
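Together, the last two hunks replace the hardcoded PositionalEncoding(len_max = 1e5) with a len_max argument threaded through the MyGPT constructor, keeping 1e5 as the default. A hypothetical construction call after this change; every hyper-parameter value below is made up for illustration, and only the keyword names and defaults come from the diff:

# All values here are illustrative, not from the repository.
model = MyGPT(vocabulary_size = 256,
              dim_model = 512, dim_keys = 64, dim_hidden = 2048,
              nb_heads = 8, nb_blocks = 12,
              dropout = 0.1, len_max = 1e5)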