From ec5bad2e7911bdf9b7851342a3ee007d41b80963 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 19 Jan 2024 14:06:05 +0100 Subject: [PATCH] Update. --- mygpt.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/mygpt.py b/mygpt.py index 2d33574..0414bb6 100755 --- a/mygpt.py +++ b/mygpt.py @@ -202,7 +202,7 @@ class DumbRec(nn.Module): attention_dropout=0.0, len_max=1e5, logger=print, - args, + args=None, ): super().__init__() @@ -333,7 +333,7 @@ class KVRec(nn.Module): attention_dropout=0.0, len_max=1e5, logger=print, - args, + args=None, ): super().__init__() @@ -487,7 +487,7 @@ class Caterpillar(nn.Module): attention_dropout=0.0, len_max=1e5, logger=print, - args, + args=None, ): super().__init__() @@ -715,7 +715,7 @@ class QKVAttention(nn.Module): causal=False, attention_dropout=0.0, logger=print, - args, + args=None, ): super().__init__() @@ -808,7 +808,7 @@ class MyGPT(nn.Module): len_max=1e5, attention_layer="kvrec", logger=print, - args, + args=None, ): super().__init__() @@ -846,7 +846,7 @@ class MyGPT(nn.Module): causal=causal, attention_dropout=dropout, logger=logger, - args, + args=args, ) elif attention_layer == "dumbrec": return DumbRec( @@ -857,7 +857,7 @@ class MyGPT(nn.Module): nb_lines=nb_lines, attention_dropout=dropout, logger=logger, - args, + args=args, ) elif attention_layer == "kvrec": return KVRec( @@ -868,7 +868,7 @@ class MyGPT(nn.Module): nb_lines=nb_lines, attention_dropout=dropout, logger=logger, - args, + args=args, ) elif attention_layer == "caterpillar": return Caterpillar( @@ -880,7 +880,7 @@ class MyGPT(nn.Module): caterpillar_height=self.caterpillar_height, attention_dropout=dropout, logger=logger, - args, + args=args, ) else: raise ValueError(f"Unknown attention type {attention_layer}.") -- 2.39.5