From 6833683bd343fd687d093d6c47cca8f1909e8b03 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 9 Jan 2024 10:45:55 +0100 Subject: [PATCH] Update. --- README.txt | 4 ++++ mygpt.py | 7 ++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/README.txt b/README.txt index abf3490..4f8d8da 100644 --- a/README.txt +++ b/README.txt @@ -1,4 +1,8 @@ +# To run a minimal test task + ./main.py --task=memory --model=4M-C --nb_epochs=5 +# To run the grid experiment with a 37M Caterpillar model. + ./main.py --task=grid --model=37M-C diff --git a/mygpt.py b/mygpt.py index f3c9a93..e7362b7 100755 --- a/mygpt.py +++ b/mygpt.py @@ -771,7 +771,12 @@ class MyGPT(nn.Module): ): super().__init__() - assert attention_layer in {"mha", "dumbrec", "kvrec", "caterpillar"} + assert attention_layer in { + "mha", + "dumbrec", + "kvrec", + "caterpillar", + }, f"Unknown attention operator {attention_layer}." if attention_layer == "caterpillar": assert nb_lines % caterpillar_height == 0 -- 2.20.1