assert dim_model % nb_heads == 0
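# input embedding: token lookup followed by dropout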
self.embedding = nn.Sequential(
- nn.Embedding(2 * vocabulary_size, dim_model),
+ nn.Embedding(vocabulary_size, dim_model),
nn.Dropout(dropout),
)
)
- x = torch.randint(100, (10, 50))
-
- y = model(x)
+
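+ # quick sanity check: push a batch of random token sequences through the model in eval mode, with gradients disabled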
+ with torch.no_grad():
+ model.eval()
+ x = torch.randint(100, (10, 50))
+ y = model(x)
+
+ print(y)
import ffutils
+import attae
+
import mygpt
import sky, grids, quiz_machine
from mygpt import (
- WithResidual,
+ CachedWithResidual,
CacheWrapper,
CachedVaswaniPositionalEncoding,
QKVAttention,
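# residual attention block: layer norm followed, presumably, by a QKVAttention layer; the Cached* wrappers appear to support incremental generation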
def attention_block(dim_model, dim_keys, nb_heads, dropout):
- return WithResidual(
+ return CachedWithResidual(
CacheWrapper(
nn.LayerNorm((dim_model,)),
),
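# residual feed-forward block: layer norm, then a linear expansion from dim_model to dim_hidden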
def ffw_block(dim_model, dim_hidden, nb_heads, dropout):
- return WithResidual(
+ return CachedWithResidual(
CacheWrapper(
nn.LayerNorm((dim_model,)),
nn.Linear(in_features=dim_model, out_features=dim_hidden),
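# MultiEmbedding presumably sums one dim_model embedding per input channel, with vocabularies of size vocabulary_size and 2, before dropout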
self.embedding = CacheWrapper(
nn.Sequential(
- MultiEmbedding((vocabulary_size, 2), dim_model), nn.Dropout(dropout)
+ MultiEmbedding((vocabulary_size, 2), dim_model),
+ nn.Dropout(dropout),
),
)
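# instantiate one MyAttentionAE per model index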
for i in range(args.nb_models):
model = MyAttentionAE(
- # model = FunctionalAE(
+ # model = attae.AttentionAE(
vocabulary_size=vocabulary_size,
dim_model=args.dim_model,
dim_keys=args.dim_keys,