X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=a23470b046faa4b4a6a2c0853c09c4c124a6679f;hb=063e25c1e1442c406746a39220f3c3590882cf51;hp=7bf25b55ba4400923215c7dced5a55052ac1488f;hpb=2673ce0d464ebe94c8d81d85dda4dfd2e24ebb22;p=mygpt.git diff --git a/mygpt.py b/mygpt.py index 7bf25b5..a23470b 100755 --- a/mygpt.py +++ b/mygpt.py @@ -119,3 +119,18 @@ class MyGPT(nn.Module): return x ###################################################################### + +if __name__ == '__main__': + vocabulary_size = 10 + x = torch.randint(vocabulary_size, (25, 100)) + + model = MyGPT( + vocabulary_size = vocabulary_size, + dim_model = 16, dim_keys = 50, dim_hidden = 100, + nb_heads = 2, nb_blocks = 3, + dropout = 0.1 + ) + + y = model(x) + +######################################################################