Fixed the size of w_o.
[mygpt.git] / main.py
diff --git a/main.py b/main.py
index c810eef..a83107b 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -206,6 +206,7 @@ class TaskPicoCLVR(Task):
         for j in range(nb_tokens):
             t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
             input = torch.tensor(t, device = self.device)
+            input = F.pad(input, (0, 1)) # Add the next token, the one to predict
             output = model(input)
             logits = output[0, -1]
             if args.synthesis_sampling:
@@ -333,6 +334,7 @@ class TaskWiki103(Task):
                  for j in range(nb_tokens):
 
                      input = self.tensorize([ t_primer + t_generated ]).to(self.device)
+                     input = F.pad(input, (0, 1)) # Add the next token, the one to predict
                      output = model(input)
                      logits = output[0, -1]
                      if args.synthesis_sampling:
@@ -498,7 +500,7 @@ for k in range(nb_epochs_finished, nb_epochs):
         for input in task.batches(split = 'test'):
             input = input.to(device)
             output = model(input)
-            loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
+            loss = F.cross_entropy(output.transpose(1, 2), input)
             acc_test_loss += loss.item() * input.size(0)
             nb_test_samples += input.size(0)