Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 9 Jan 2024 22:01:12 +0000 (23:01 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 9 Jan 2024 22:01:12 +0000 (23:01 +0100)
main.py
mygpt.py

diff --git a/main.py b/main.py
index cae20f8..74e1d6c 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -107,8 +107,6 @@ parser.add_argument("--caterpillar_height", type=int, default=None)
 
 parser.add_argument("--rho", type=float, default=0.0)
 
-parser.add_argument("--dim_rec_v", type=int, default=None)
-
 parser.add_argument("--nb_blocks", type=int, default=None)
 
 parser.add_argument("--dropout", type=float, default=0.1)
@@ -332,7 +330,6 @@ default_model_args = {
         "dim_keys": 32,
         "dim_hidden": 32,
         "nb_heads": 2,
-        "dim_rec_v": 16,
         "nb_blocks": 2,
     },
     "17K-C": {
@@ -343,7 +340,6 @@ default_model_args = {
         "nb_heads": 2,
         "nb_lines": 16,
         "caterpillar_height": 4,
-        "dim_rec_v": 16,
         "nb_blocks": 2,
     },
     "4M": {
@@ -352,7 +348,6 @@ default_model_args = {
         "dim_keys": 32,
         "dim_hidden": 1024,
         "nb_heads": 4,
-        "dim_rec_v": 64,
         "nb_blocks": 6,
     },
     "4M-C": {
@@ -363,7 +358,6 @@ default_model_args = {
         "nb_heads": 4,
         "nb_lines": 32,
         "caterpillar_height": 4,
-        "dim_rec_v": 64,  # dim_model / nb_heads
         "nb_blocks": 6,
     },
     "37M": {
@@ -372,7 +366,6 @@ default_model_args = {
         "dim_keys": 64,
         "dim_hidden": 2048,
         "nb_heads": 8,
-        "dim_rec_v": 64,
         "nb_blocks": 12,
     },
     "37M-C": {
@@ -383,7 +376,6 @@ default_model_args = {
         "nb_heads": 8,
         "nb_lines": 256,
         "caterpillar_height": 32,
-        "dim_rec_v": 64,
         "nb_blocks": 12,
     },
     "122M": {
@@ -392,7 +384,6 @@ default_model_args = {
         "dim_keys": 64,
         "dim_hidden": 2048,
         "nb_heads": 8,
-        "dim_rec_v": 96,
         "nb_blocks": 24,
     },
     "122M-C": {
@@ -402,7 +393,6 @@ default_model_args = {
         "dim_hidden": 2048,
         "nb_heads": 8,
         "nb_lines": 128,
-        "dim_rec_v": 96,
         "nb_blocks": 24,
     },
     "352M": {
@@ -411,7 +401,6 @@ default_model_args = {
         "dim_keys": 64,
         "dim_hidden": 2048,
         "nb_heads": 8,
-        "dim_rec_v": 128,
         "nb_blocks": 48,
     },
     "352M-C": {
@@ -421,7 +410,6 @@ default_model_args = {
         "dim_hidden": 2048,
         "nb_heads": 8,
         "nb_lines": 128,
-        "dim_rec_v": 128,
         "nb_blocks": 48,
     },
 }
@@ -736,7 +724,6 @@ model = mygpt.MyGPT(
     nb_heads=args.nb_heads,
     nb_lines=args.nb_lines,
     caterpillar_height=args.caterpillar_height,
-    dim_rec_v=args.dim_rec_v,
     nb_blocks=args.nb_blocks,
     causal=True,
     dropout=args.dropout,
index bd870bc..95e5527 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -774,7 +774,6 @@ class MyGPT(nn.Module):
         nb_blocks,
         nb_lines=None,
         caterpillar_height=None,
-        dim_rec_v=-1,
         causal=False,
         dropout=0.0,
         len_max=1e5,
@@ -820,7 +819,7 @@ class MyGPT(nn.Module):
                 return DumbRec(
                     dim_model=dim_model,
                     dim_qk=dim_keys,
-                    dim_v=dim_rec_v,
+                    dim_v=dim_model // nb_heads,
                     nb_heads=nb_heads,
                     nb_lines=nb_lines,
                     attention_dropout=dropout,
@@ -829,7 +828,7 @@ class MyGPT(nn.Module):
                 return KVRec(
                     dim_model=dim_model,
                     dim_qk=dim_keys,
-                    dim_v=dim_rec_v,
+                    dim_v=dim_model // nb_heads,
                     nb_heads=nb_heads,
                     nb_lines=nb_lines,
                     attention_dropout=dropout,
@@ -838,7 +837,7 @@ class MyGPT(nn.Module):
                 return Caterpillar(
                     dim_model=dim_model,
                     dim_qk=dim_keys,
-                    dim_v=dim_rec_v,
+                    dim_v=dim_model // nb_heads,
                     nb_heads=nb_heads,
                     caterpillar_length=self.caterpillar_length,
                     caterpillar_height=self.caterpillar_height,