From 75e1ddcb8de30a4a7be16c80c4f258da662837a6 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 9 Jan 2024 23:01:12 +0100 Subject: [PATCH] Update. --- main.py | 13 ------------- mygpt.py | 7 +++---- 2 files changed, 3 insertions(+), 17 deletions(-) diff --git a/main.py b/main.py index cae20f8..74e1d6c 100755 --- a/main.py +++ b/main.py @@ -107,8 +107,6 @@ parser.add_argument("--caterpillar_height", type=int, default=None) parser.add_argument("--rho", type=float, default=0.0) -parser.add_argument("--dim_rec_v", type=int, default=None) - parser.add_argument("--nb_blocks", type=int, default=None) parser.add_argument("--dropout", type=float, default=0.1) @@ -332,7 +330,6 @@ default_model_args = { "dim_keys": 32, "dim_hidden": 32, "nb_heads": 2, - "dim_rec_v": 16, "nb_blocks": 2, }, "17K-C": { @@ -343,7 +340,6 @@ default_model_args = { "nb_heads": 2, "nb_lines": 16, "caterpillar_height": 4, - "dim_rec_v": 16, "nb_blocks": 2, }, "4M": { @@ -352,7 +348,6 @@ default_model_args = { "dim_keys": 32, "dim_hidden": 1024, "nb_heads": 4, - "dim_rec_v": 64, "nb_blocks": 6, }, "4M-C": { @@ -363,7 +358,6 @@ default_model_args = { "nb_heads": 4, "nb_lines": 32, "caterpillar_height": 4, - "dim_rec_v": 64, # dim_model / nb_heads "nb_blocks": 6, }, "37M": { @@ -372,7 +366,6 @@ default_model_args = { "dim_keys": 64, "dim_hidden": 2048, "nb_heads": 8, - "dim_rec_v": 64, "nb_blocks": 12, }, "37M-C": { @@ -383,7 +376,6 @@ default_model_args = { "nb_heads": 8, "nb_lines": 256, "caterpillar_height": 32, - "dim_rec_v": 64, "nb_blocks": 12, }, "122M": { @@ -392,7 +384,6 @@ default_model_args = { "dim_keys": 64, "dim_hidden": 2048, "nb_heads": 8, - "dim_rec_v": 96, "nb_blocks": 24, }, "122M-C": { @@ -402,7 +393,6 @@ default_model_args = { "dim_hidden": 2048, "nb_heads": 8, "nb_lines": 128, - "dim_rec_v": 96, "nb_blocks": 24, }, "352M": { @@ -411,7 +401,6 @@ default_model_args = { "dim_keys": 64, "dim_hidden": 2048, "nb_heads": 8, - "dim_rec_v": 128, "nb_blocks": 48, }, "352M-C": { @@ -421,7 +410,6 @@ default_model_args = { "dim_hidden": 2048, "nb_heads": 8, "nb_lines": 128, - "dim_rec_v": 128, "nb_blocks": 48, }, } @@ -736,7 +724,6 @@ model = mygpt.MyGPT( nb_heads=args.nb_heads, nb_lines=args.nb_lines, caterpillar_height=args.caterpillar_height, - dim_rec_v=args.dim_rec_v, nb_blocks=args.nb_blocks, causal=True, dropout=args.dropout, diff --git a/mygpt.py b/mygpt.py index bd870bc..95e5527 100755 --- a/mygpt.py +++ b/mygpt.py @@ -774,7 +774,6 @@ class MyGPT(nn.Module): nb_blocks, nb_lines=None, caterpillar_height=None, - dim_rec_v=-1, causal=False, dropout=0.0, len_max=1e5, @@ -820,7 +819,7 @@ class MyGPT(nn.Module): return DumbRec( dim_model=dim_model, dim_qk=dim_keys, - dim_v=dim_rec_v, + dim_v=dim_model // nb_heads, nb_heads=nb_heads, nb_lines=nb_lines, attention_dropout=dropout, @@ -829,7 +828,7 @@ class MyGPT(nn.Module): return KVRec( dim_model=dim_model, dim_qk=dim_keys, - dim_v=dim_rec_v, + dim_v=dim_model // nb_heads, nb_heads=nb_heads, nb_lines=nb_lines, attention_dropout=dropout, @@ -838,7 +837,7 @@ class MyGPT(nn.Module): return Caterpillar( dim_model=dim_model, dim_qk=dim_keys, - dim_v=dim_rec_v, + dim_v=dim_model // nb_heads, nb_heads=nb_heads, caterpillar_length=self.caterpillar_length, caterpillar_height=self.caterpillar_height, -- 2.39.5