######################################################################
-#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
def train_auto_encoder():
        nb_blocks=args.nb_blocks,
        causal=False,
        dropout=args.dropout,
-        auto_encoder_dim=64,
    ).to(main_device)
+    model.make_auto_encoder(auto_encoder_dim=64)
+    # make_auto_encoder() creates new sub-modules after .to(), so move the model again
+    model.to(main_device)
+
    test_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
        nb_blocks,
        causal=False,
        dropout=0.0,
-        auto_encoder_dim=-1,
        len_max=1e5,
    ):
        super().__init__()
            ),
        ]
-        if auto_encoder_dim > 0:
-            self.encoder = nn.Sequential(
-                *(
-                    trunk_blocks[: nb_blocks // 2]
-                    + [EncoderHead(dim_model, auto_encoder_dim)]
-                )
-            )
-
-            self.decoder = nn.Sequential(
-                *(
-                    [
-                        DecoderBottom(auto_encoder_dim, dim_model),
-                        AddPositionalEncoding(len_max),
-                    ]
-                    + trunk_blocks[nb_blocks // 2 :]
-                )
-            )
-
+        # kept so that make_auto_encoder() can rebuild the encoder / decoder later
+        self.dim_model = dim_model
+        self.len_max = len_max
+
        self.trunk = nn.Sequential(*trunk_blocks)
        self.readout = CacheWrapper(
        bs = self.readout(bs)
        return bs
+    def make_auto_encoder(self, auto_encoder_dim):
+        # the first half of the trunk, topped with a bottleneck head, becomes
+        # the encoder; a projection back to dim_model followed by the second
+        # half of the trunk becomes the decoder
+        trunk_blocks = list(self.trunk)
+        nb_blocks = len(trunk_blocks)
+
+        self.encoder = nn.Sequential(
+            *(
+                trunk_blocks[: nb_blocks // 2]
+                + [EncoderHead(self.dim_model, auto_encoder_dim)]
+            )
+        )
+
+        self.decoder = nn.Sequential(
+            *(
+                [
+                    DecoderBottom(auto_encoder_dim, self.dim_model),
+                    AddPositionalEncoding(self.len_max),
+                ]
+                + trunk_blocks[nb_blocks // 2 :]
+            )
+        )
+
    def encode(self, bs):
        bs = self.embedding(bs)
        z = self.encoder(bs)
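
As a sanity check of the new flow (build the model, graft the auto-encoder on afterwards, then move everything to the device and create the optimizer), here is a minimal self-contained sketch of the same pattern. ToyModel and its plain linear blocks are stand-ins invented for illustration; they are not the repository's MyGPT, EncoderHead or DecoderBottom classes.

import torch
from torch import nn

class ToyModel(nn.Module):
    # stand-in for MyGPT: an embedding, a trunk of blocks, a readout
    def __init__(self, vocab_size=16, dim_model=32, nb_blocks=4):
        super().__init__()
        self.dim_model = dim_model
        self.embedding = nn.Embedding(vocab_size, dim_model)
        self.trunk = nn.Sequential(
            *[
                nn.Sequential(nn.Linear(dim_model, dim_model), nn.GELU())
                for _ in range(nb_blocks)
            ]
        )
        self.readout = nn.Linear(dim_model, vocab_size)

    def make_auto_encoder(self, auto_encoder_dim):
        # same idea as the patch: first half of the trunk + bottleneck head
        # -> encoder, projection back to dim_model + second half -> decoder
        blocks = list(self.trunk)
        half = len(blocks) // 2
        self.encoder = nn.Sequential(
            *blocks[:half], nn.Linear(self.dim_model, auto_encoder_dim)
        )
        self.decoder = nn.Sequential(
            nn.Linear(auto_encoder_dim, self.dim_model), *blocks[half:]
        )

    def encode(self, bs):
        return self.encoder(self.embedding(bs))

    def decode(self, z):
        return self.readout(self.decoder(z))

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ToyModel()
    model.make_auto_encoder(auto_encoder_dim=8)
    model.to(device)  # after make_auto_encoder, so the new modules move too
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    x = torch.randint(0, 16, (2, 5), device=device)
    logits = model.decode(model.encode(x))  # (2, 5, 16)
    loss = nn.functional.cross_entropy(logits.transpose(1, 2), x)
    loss.backward()
    optimizer.step()
    print(loss.item())

The call order matters: the optimizer must be created after make_auto_encoder(), and the model must be (re)moved to the device afterwards, otherwise the bottleneck parameters stay on the CPU and out of the optimizer's parameter groups.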