projects
/
mygptrnn.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
c45d89e
)
Update.
author
François Fleuret
<francois@fleuret.org>
Fri, 19 Jan 2024 13:06:05 +0000
(14:06 +0100)
committer
François Fleuret
<francois@fleuret.org>
Fri, 19 Jan 2024 13:06:05 +0000
(14:06 +0100)
mygpt.py
patch
|
blob
|
history
diff --git
a/mygpt.py
b/mygpt.py
index
2d33574
..
0414bb6
100755
(executable)
--- a/
mygpt.py
+++ b/
mygpt.py
@@
-202,7
+202,7
@@
class DumbRec(nn.Module):
attention_dropout=0.0,
len_max=1e5,
logger=print,
attention_dropout=0.0,
len_max=1e5,
logger=print,
- args,
+ args
=None
,
):
super().__init__()
):
super().__init__()
@@
-333,7
+333,7
@@
class KVRec(nn.Module):
attention_dropout=0.0,
len_max=1e5,
logger=print,
attention_dropout=0.0,
len_max=1e5,
logger=print,
- args,
+ args
=None
,
):
super().__init__()
):
super().__init__()
@@
-487,7
+487,7
@@
class Caterpillar(nn.Module):
attention_dropout=0.0,
len_max=1e5,
logger=print,
attention_dropout=0.0,
len_max=1e5,
logger=print,
- args,
+ args
=None
,
):
super().__init__()
):
super().__init__()
@@
-715,7
+715,7
@@
class QKVAttention(nn.Module):
causal=False,
attention_dropout=0.0,
logger=print,
causal=False,
attention_dropout=0.0,
logger=print,
- args,
+ args
=None
,
):
super().__init__()
):
super().__init__()
@@
-808,7
+808,7
@@
class MyGPT(nn.Module):
len_max=1e5,
attention_layer="kvrec",
logger=print,
len_max=1e5,
attention_layer="kvrec",
logger=print,
- args,
+ args
=None
,
):
super().__init__()
):
super().__init__()
@@
-846,7
+846,7
@@
class MyGPT(nn.Module):
causal=causal,
attention_dropout=dropout,
logger=logger,
causal=causal,
attention_dropout=dropout,
logger=logger,
- args,
+ args
=args
,
)
elif attention_layer == "dumbrec":
return DumbRec(
)
elif attention_layer == "dumbrec":
return DumbRec(
@@
-857,7
+857,7
@@
class MyGPT(nn.Module):
nb_lines=nb_lines,
attention_dropout=dropout,
logger=logger,
nb_lines=nb_lines,
attention_dropout=dropout,
logger=logger,
- args,
+ args
=args
,
)
elif attention_layer == "kvrec":
return KVRec(
)
elif attention_layer == "kvrec":
return KVRec(
@@
-868,7
+868,7
@@
class MyGPT(nn.Module):
nb_lines=nb_lines,
attention_dropout=dropout,
logger=logger,
nb_lines=nb_lines,
attention_dropout=dropout,
logger=logger,
- args,
+ args
=args
,
)
elif attention_layer == "caterpillar":
return Caterpillar(
)
elif attention_layer == "caterpillar":
return Caterpillar(
@@
-880,7
+880,7
@@
class MyGPT(nn.Module):
caterpillar_height=self.caterpillar_height,
attention_dropout=dropout,
logger=logger,
caterpillar_height=self.caterpillar_height,
attention_dropout=dropout,
logger=logger,
- args,
+ args
=args
,
)
else:
raise ValueError(f"Unknown attention type {attention_layer}.")
)
else:
raise ValueError(f"Unknown attention type {attention_layer}.")