From: François Fleuret Date: Fri, 19 Jan 2024 13:06:05 +0000 (+0100) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=ec5bad2e7911bdf9b7851342a3ee007d41b80963;p=mygptrnn.git Update. --- diff --git a/mygpt.py b/mygpt.py index 2d33574..0414bb6 100755 --- a/mygpt.py +++ b/mygpt.py @@ -202,7 +202,7 @@ class DumbRec(nn.Module): attention_dropout=0.0, len_max=1e5, logger=print, - args, + args=None, ): super().__init__() @@ -333,7 +333,7 @@ class KVRec(nn.Module): attention_dropout=0.0, len_max=1e5, logger=print, - args, + args=None, ): super().__init__() @@ -487,7 +487,7 @@ class Caterpillar(nn.Module): attention_dropout=0.0, len_max=1e5, logger=print, - args, + args=None, ): super().__init__() @@ -715,7 +715,7 @@ class QKVAttention(nn.Module): causal=False, attention_dropout=0.0, logger=print, - args, + args=None, ): super().__init__() @@ -808,7 +808,7 @@ class MyGPT(nn.Module): len_max=1e5, attention_layer="kvrec", logger=print, - args, + args=None, ): super().__init__() @@ -846,7 +846,7 @@ class MyGPT(nn.Module): causal=causal, attention_dropout=dropout, logger=logger, - args, + args=args, ) elif attention_layer == "dumbrec": return DumbRec( @@ -857,7 +857,7 @@ class MyGPT(nn.Module): nb_lines=nb_lines, attention_dropout=dropout, logger=logger, - args, + args=args, ) elif attention_layer == "kvrec": return KVRec( @@ -868,7 +868,7 @@ class MyGPT(nn.Module): nb_lines=nb_lines, attention_dropout=dropout, logger=logger, - args, + args=args, ) elif attention_layer == "caterpillar": return Caterpillar( @@ -880,7 +880,7 @@ class MyGPT(nn.Module): caterpillar_height=self.caterpillar_height, attention_dropout=dropout, logger=logger, - args, + args=args, ) else: raise ValueError(f"Unknown attention type {attention_layer}.")