Update.
[pytorch.git] / tinyae.py
index 1608786..806559e 100755 (executable)
--- a/tinyae.py
+++ b/tinyae.py
@@ -55,7 +55,7 @@ def log_string(s):
 
 class AutoEncoder(nn.Module):
     def __init__(self, nb_channels, embedding_dim):
-        super(AutoEncoder, self).__init__()
+        super().__init__()
 
         self.encoder = nn.Sequential(
             nn.Conv2d(1, nb_channels, kernel_size=5),  # to 24x24
@@ -92,8 +92,8 @@ class AutoEncoder(nn.Module):
         return self.decoder(z.view(z.size(0), -1, 1, 1))
 
     def forward(self, x):
-        x = self.encoder(x)
-        x = self.decoder(x)
+        x = self.encode(x)
+        x = self.decode(x)
         return x
 
 
@@ -124,20 +124,22 @@ test_input.sub_(mu).div_(std)
 
 ######################################################################
 
-for epoch in range(args.nb_epochs):
-    acc_loss = 0
+for n_epoch in range(args.nb_epochs):
+    acc_train_loss = 0
 
     for input in train_input.split(args.batch_size):
         output = model(input)
-        loss = 0.5 * (output - input).pow(2).sum() / input.size(0)
+        train_loss = F.mse_loss(output, input)
 
         optimizer.zero_grad()
-        loss.backward()
+        train_loss.backward()
         optimizer.step()
 
-        acc_loss += loss.item()
+        acc_train_loss += train_loss.detach().item() * input.size(0)
 
-    log_string("acc_loss {:d} {:f}.".format(epoch, acc_loss))
+    train_loss = acc_train_loss / train_input.size(0)
+
+    log_string(f"train_loss {n_epoch} {train_loss}")
 
 ######################################################################