Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 15 Aug 2024 13:51:59 +0000 (15:51 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 15 Aug 2024 13:51:59 +0000 (15:51 +0200)
main.py

diff --git a/main.py b/main.py
index 4375985..27404e8 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -952,6 +952,20 @@ if args.dirty_debug:
 ######################################################################
 
 
+class Folder(nn.Module):
+    def forward(self, x):
+        return x.mean(dim=1)
+
+
+class Unfolder(nn.Module):
+    def __init__(self, T, dim):
+        super().__init__()
+        self.biases = nn.Parameter(torch.randn(T, dim))
+
+    def forward(self, x):
+        return x[:, None, :] + self.biases[None, :, :]
+
+
 class Recorder(nn.Module):
     def __init__(self, tape):
         super().__init__()
@@ -969,19 +983,59 @@ if args.test == "mlp":
     model.trunk.insert(L // 2 + 1, Recorder(tape_output))
     model.trunk.insert(L // 2, Recorder(tape_input))
 
-    print(model.trunk)
-    train_input = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples)
+    mlp = nn.Sequential(
+        nn.Linear(args.dim_model, args.dim_model),
+        nn.ReLU(),
+        nn.Linear(args.dim_model, args.dim_model),
+        nn.ReLU(),
+        nn.Linear(args.dim_model, 8 * args.dim_model),
+        Folder(),
+        Unfolder(404, 8 * args.dim_model),
+        nn.Linear(8 * args.dim_model, args.dim_model),
+        nn.ReLU(),
+        nn.Linear(args.dim_model, args.dim_model),
+        nn.ReLU(),
+        nn.Linear(args.dim_model, args.dim_model),
+    ).to(main_device)
+
+    mlp.optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-4)
 
-    with torch.autograd.no_grad():
-        model.to(main_device).eval()
-        for input in train_input.split(args.batch_size):
+    for n_epoch in range(args.nb_epochs):
+        train_input = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples)
+
+        tape_input.clear()
+        tape_output.clear()
+
+        with torch.autograd.no_grad():
+            model.to(main_device).eval()
+            for input in train_input.split(args.batch_size):
+                input = input.to(main_device)
+                output = model(mygpt.BracketedSequence(input)).x
+
+        train_input = torch.cat([bs.x for bs in tape_input], dim=0)
+        train_targets = torch.cat([bs.x for bs in tape_output], dim=0)
+
+        nb_train_samples, acc_train_loss = 0, 0.0
+        src = zip(
+            train_input.split(args.batch_size), train_targets.split(args.batch_size)
+        )
+        for input, targets in tqdm.tqdm(
+            src,
+            dynamic_ncols=True,
+            desc="train",
+            total=train_input.size(0) // args.batch_size,
+        ):
             input = input.to(main_device)
-            output = model(mygpt.BracketedSequence(input)).x
+            output = mlp(input)
+            loss = F.mse_loss(output, targets) + output.abs().sum()
+            acc_train_loss += loss.item() * input.size(0)
+            nb_train_samples += input.size(0)
 
-    train_input = torch.cat([bs.x for bs in tape_input], dim=0)
-    train_targets = torch.cat([bs.x for bs in tape_output], dim=0)
+            mlp.optimizer.zero_grad()
+            loss.backward()
+            mlp.optimizer.step()
 
-    print(f"{train_input.size()=} {train_targets.size()=}")
+        log_string(f"mlp_loss {n_epoch} train {acc_train_loss/nb_train_samples}")
 
     exit(0)