From 48eba3979606e07ae60365756433db028e0777ae Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 15 Aug 2024 15:51:59 +0200 Subject: [PATCH] Update. --- main.py | 72 +++++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 63 insertions(+), 9 deletions(-) diff --git a/main.py b/main.py index 4375985..27404e8 100755 --- a/main.py +++ b/main.py @@ -952,6 +952,20 @@ if args.dirty_debug: ###################################################################### +class Folder(nn.Module): + def forward(self, x): + return x.mean(dim=1) + + +class Unfolder(nn.Module): + def __init__(self, T, dim): + super().__init__() + self.biases = nn.Parameter(torch.randn(T, dim)) + + def forward(self, x): + return x[:, None, :] + self.biases[None, :, :] + + class Recorder(nn.Module): def __init__(self, tape): super().__init__() @@ -969,19 +983,59 @@ if args.test == "mlp": model.trunk.insert(L // 2 + 1, Recorder(tape_output)) model.trunk.insert(L // 2, Recorder(tape_input)) - print(model.trunk) - train_input = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples) + mlp = nn.Sequential( + nn.Linear(args.dim_model, args.dim_model), + nn.ReLU(), + nn.Linear(args.dim_model, args.dim_model), + nn.ReLU(), + nn.Linear(args.dim_model, 8 * args.dim_model), + Folder(), + Unfolder(404, 8 * args.dim_model), + nn.Linear(8 * args.dim_model, args.dim_model), + nn.ReLU(), + nn.Linear(args.dim_model, args.dim_model), + nn.ReLU(), + nn.Linear(args.dim_model, args.dim_model), + ).to(main_device) + + mlp.optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-4) - with torch.autograd.no_grad(): - model.to(main_device).eval() - for input in train_input.split(args.batch_size): + for n_epoch in range(args.nb_epochs): + train_input = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples) + + tape_input.clear() + tape_output.clear() + + with torch.autograd.no_grad(): + model.to(main_device).eval() + for input in train_input.split(args.batch_size): + input = input.to(main_device) + output = model(mygpt.BracketedSequence(input)).x + + train_input = torch.cat([bs.x for bs in tape_input], dim=0) + train_targets = torch.cat([bs.x for bs in tape_output], dim=0) + + nb_train_samples, acc_train_loss = 0, 0.0 + src = zip( + train_input.split(args.batch_size), train_targets.split(args.batch_size) + ) + for input, targets in tqdm.tqdm( + src, + dynamic_ncols=True, + desc="train", + total=train_input.size(0) // args.batch_size, + ): input = input.to(main_device) - output = model(mygpt.BracketedSequence(input)).x + output = mlp(input) + loss = F.mse_loss(output, targets) + output.abs().sum() + acc_train_loss += loss.item() * input.size(0) + nb_train_samples += input.size(0) - train_input = torch.cat([bs.x for bs in tape_input], dim=0) - train_targets = torch.cat([bs.x for bs in tape_output], dim=0) + mlp.optimizer.zero_grad() + loss.backward() + mlp.optimizer.step() - print(f"{train_input.size()=} {train_targets.size()=}") + log_string(f"mlp_loss {n_epoch} train {acc_train_loss/nb_train_samples}") exit(0) -- 2.39.5