X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=bd870bc67ec1c8895abe4cd8c81d2b113e4666f9;hb=5c298b53859b4d97aa85331034af952aae3b0c05;hp=b885e218be6704cd86afed18966e5609e9873369;hpb=42831bd654d030b71bca88578d041279018f836c;p=mygptrnn.git diff --git a/mygpt.py b/mygpt.py index b885e21..bd870bc 100755 --- a/mygpt.py +++ b/mygpt.py @@ -10,6 +10,8 @@ # with a caching mechanism for keys and values to avoid a O(N^3) cost # for auto-regression. +# This implementation is equipped with RNN layers to replace the MHA + import math, warnings import torch, einops