######################################################################
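+# Folder collapses a whole token sequence into a single vector by averaging
+# over the sequence dimension: (N, T, dim) -> (N, dim).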
+class Folder(nn.Module):
+    def forward(self, x):
+        return x.mean(dim=1)
+
+
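+# Unfolder is the inverse of Folder: it broadcasts one vector per sample back
+# into a length-T sequence by adding a learned per-position bias,
+# (N, dim) -> (N, T, dim).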
+class Unfolder(nn.Module):
+    def __init__(self, T, dim):
+        super().__init__()
+        self.biases = nn.Parameter(torch.randn(T, dim))
+
+    def forward(self, x):
+        return x[:, None, :] + self.biases[None, :, :]
+
+
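# A pass-through module that appends whatever flows through it to `tape`
# (a plain list); two recorders are inserted below to capture the input and
# output of the trunk's middle block.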
class Recorder(nn.Module):
    def __init__(self, tape):
        super().__init__()
        self.tape = tape

    def forward(self, bs):
        # record the sequence flowing through, then pass it on unchanged
        self.tape.append(bs)
        return bs


tape_input, tape_output = [], []

# bracket the middle block of the trunk: inserting at L // 2 + 1 first means
# the second insertion at L // 2 does not shift the first recorder
L = len(model.trunk)
model.trunk.insert(L // 2 + 1, Recorder(tape_output))
model.trunk.insert(L // 2, Recorder(tape_input))
-print(model.trunk)
-train_input = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples)
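+# The MLP trained below to imitate the recorded block: pointwise linear/ReLU
+# layers widened to 8 * dim_model, with a Folder / Unfolder pair in the middle
+# that forces all information through a single vector per sequence.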
+mlp = nn.Sequential(
+    nn.Linear(args.dim_model, args.dim_model),
+    nn.ReLU(),
+    nn.Linear(args.dim_model, args.dim_model),
+    nn.ReLU(),
+    nn.Linear(args.dim_model, 8 * args.dim_model),
+    Folder(),
+    Unfolder(404, 8 * args.dim_model),
+    nn.Linear(8 * args.dim_model, args.dim_model),
+    nn.ReLU(),
+    nn.Linear(args.dim_model, args.dim_model),
+    nn.ReLU(),
+    nn.Linear(args.dim_model, args.dim_model),
+).to(main_device)
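+# Since nn.Linear acts on the last dimension, the layers before Folder apply
+# per position; shapes run (N, T, d) -> (N, T, 8d) -> (N, 8d) -> (N, 404, 8d)
+# -> (N, 404, d), so the hard-coded 404 must be the sequence length.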
+
+mlp.optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-4)
-with torch.autograd.no_grad():
-    model.to(main_device).eval()
-    for input in train_input.split(args.batch_size):
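+# Each epoch: generate fresh quizzes, push them through the frozen GPT so the
+# recorders fill the tapes, then regress the MLP from the recorded inputs to
+# the recorded outputs.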
+for n_epoch in range(args.nb_epochs):
+    train_input = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples)
+
+    tape_input.clear()
+    tape_output.clear()
+
+    # run the frozen model only for its side effect: the recorders fill the
+    # tapes with the middle block's input and output
+    with torch.autograd.no_grad():
+        model.to(main_device).eval()
+        for input in train_input.split(args.batch_size):
+            input = input.to(main_device)
+            output = model(mygpt.BracketedSequence(input)).x
+
+    train_input = torch.cat([bs.x for bs in tape_input], dim=0)
+    train_targets = torch.cat([bs.x for bs in tape_output], dim=0)
+
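+    # standard supervised loop on the harvested (input, target) pairs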
    nb_train_samples, acc_train_loss = 0, 0.0
    src = zip(
        train_input.split(args.batch_size), train_targets.split(args.batch_size)
    )
    for input, targets in tqdm.tqdm(
        src,
        dynamic_ncols=True,
        desc="train",
        total=train_input.size(0) // args.batch_size,
    ):
        input = input.to(main_device)
-        output = model(mygpt.BracketedSequence(input)).x
+        output = mlp(input)
+        # distillation objective: match the recorded block output, plus an
+        # (unweighted, summed) L1 penalty on the MLP's output activations
+        loss = F.mse_loss(output, targets) + output.abs().sum()
+        acc_train_loss += loss.item() * input.size(0)
+        nb_train_samples += input.size(0)
-train_input = torch.cat([bs.x for bs in tape_input], dim=0)
-train_targets = torch.cat([bs.x for bs in tape_output], dim=0)
+        mlp.optimizer.zero_grad()
+        loss.backward()
+        mlp.optimizer.step()
-print(f"{train_input.size()=} {train_targets.size()=}")
+    log_string(f"mlp_loss {n_epoch} train {acc_train_loss/nb_train_samples}")
exit(0)