assert dim_model % nb_heads == 0
self.embedding = nn.Sequential(
- nn.Embedding(2 * vocabulary_size, dim_model), nn.Dropout(dropout)
+ nn.Embedding(2 * vocabulary_size, dim_model),
+ nn.Dropout(dropout),
)
self.positional_encoding = VaswaniPositionalEncoding(len_max)
######################################################################
+class MaskedAttentionAE(nn.Module):
+ def __init__(
+ self,
+ vocabulary_size,
+ dim_model,
+ dim_keys,
+ dim_hidden,
+ nb_heads,
+ nb_blocks,
+ dropout=0.0,
+ len_max=1e5,
+ ):
+ super().__init__()
+ self.core = AttentionAE(
+ vocabulary_size * 2,
+ dim_model,
+ dim_keys,
+ dim_hidden,
+ nb_heads,
+ nb_blocks,
+ dropout=dropout,
+ len_max=len_max,
+ )
+
+ def forward(self, x):
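+ # fold each (token, flag) input pair into a single index over the
+ # doubled vocabulary before running the core auto-encoder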
+ x = x[:, :, 0] * 2 + x[:, :, 1]
+ return self.core(x)
+
+
+######################################################################
+
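+# Illustrative usage sketch (an assumption, not part of the test below): with
+# vocabulary_size=100 the model expects inputs of shape (N, T, 2) whose second
+# channel is a binary flag, so the core sees token indices in [0, 200).
+#
+#     model = MaskedAttentionAE(
+#         vocabulary_size=100,
+#         dim_model=16,
+#         dim_keys=4,
+#         dim_hidden=16,
+#         nb_heads=4,
+#         nb_blocks=4,
+#     )
+#     x = torch.cat(
+#         [torch.randint(100, (3, 50, 1)), torch.randint(2, (3, 50, 1))], dim=2
+#     )
+#     y = model(x)  # logits from the underlying AttentionAE
+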
+
if __name__ == "__main__":
model = AttentionAE(
vocabulary_size=100,
import torch.multiprocessing as mp
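+# "high" lets float32 matmuls use faster reduced-precision (e.g. TF32) kernels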
+torch.set_float32_matmul_precision("high")
+
######################################################################
parser = argparse.ArgumentParser(
local_device,
c_quizzes=None,
alien_quiz_machine=None,
- nb_aliens=None,
desc=None,
batch_size=args.batch_size,
):
args.nb_train_samples,
data_structures,
local_device,
- c_quizzes,
- "training",
+ c_quizzes=c_quizzes,
+ desc="training",
):
x_0 = x_0.to(local_device)
mask_generate = mask_generate.to(local_device)
######################################################################
-# import attae
+import attae
models = []
for i in range(args.nb_models):
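+ # one masked auto-encoder per model; inputs are (token, flag) pairs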
- model = MyAttentionAE(
- # model = attae.AttentionAE(
+ # model = MyAttentionAE(
+ model = attae.MaskedAttentionAE(
vocabulary_size=vocabulary_size,
dim_model=args.dim_model,
dim_keys=args.dim_keys,
def save_models(models, suffix=""):
if suffix != "":
suffix = "_" + suffix
+
for model in models:
filename = f"ae_{model.id:03d}{suffix}.pth"
torch.save(
start_time = time.perf_counter()
+ # None if c_quizzes is None else c_quizzes[agreements[:, model.id]],
+
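+ # run one_ae_epoch for the weakest models in parallel, one argument
+ # tuple per (model, gpu) pair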
multithread_execution(
one_ae_epoch,
[
- (
- model,
- quiz_machine,
- n_epoch,
- None if c_quizzes is None else c_quizzes[agreements[:, model.id]],
- gpu,
- )
+ (model, quiz_machine, n_epoch, c_quizzes, gpu)
for model, gpu in zip(weakest_models, gpus)
],
)