Update.
[mygptrnn.git] / mygpt.py
index f3c9a93..e7362b7 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -771,7 +771,12 @@ class MyGPT(nn.Module):
     ):
         super().__init__()
 
-        assert attention_layer in {"mha", "dumbrec", "kvrec", "caterpillar"}
+        assert attention_layer in {
+            "mha",
+            "dumbrec",
+            "kvrec",
+            "caterpillar",
+        }, f"Unknown attention operator {attention_layer}."
 
         if attention_layer == "caterpillar":
             assert nb_lines % caterpillar_height == 0