From: François Fleuret
Date: Tue, 9 Jan 2024 09:45:55 +0000 (+0100)
Subject: Update.
X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=6833683bd343fd687d093d6c47cca8f1909e8b03;p=mygptrnn.git

Update.
---

diff --git a/README.txt b/README.txt
index abf3490..4f8d8da 100644
--- a/README.txt
+++ b/README.txt
@@ -1,4 +1,8 @@
+# To run a minimal test task
 
+  ./main.py --task=memory --model=4M-C --nb_epochs=5
 
+# To run the grid experiment with a 37M Caterpillar model.
 
+  ./main.py --task=grid --model=37M-C
 
diff --git a/mygpt.py b/mygpt.py
index f3c9a93..e7362b7 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -771,7 +771,12 @@ class MyGPT(nn.Module):
     ):
         super().__init__()
 
-        assert attention_layer in {"mha", "dumbrec", "kvrec", "caterpillar"}
+        assert attention_layer in {
+            "mha",
+            "dumbrec",
+            "kvrec",
+            "caterpillar",
+        }, f"Unknown attention operator {attention_layer}."
 
         if attention_layer == "caterpillar":
             assert nb_lines % caterpillar_height == 0