X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=tinyae.py;h=b4f3aba8ed9fd3c1d23112818049551be2905e21;hp=160878657efa7686c75ad8083b30fbfbb4e9d845;hb=HEAD;hpb=05b9b133a45ac9bd5abe6f8b6d29095f9c82797a diff --git a/tinyae.py b/tinyae.py index 1608786..806559e 100755 --- a/tinyae.py +++ b/tinyae.py @@ -55,7 +55,7 @@ def log_string(s): class AutoEncoder(nn.Module): def __init__(self, nb_channels, embedding_dim): - super(AutoEncoder, self).__init__() + super().__init__() self.encoder = nn.Sequential( nn.Conv2d(1, nb_channels, kernel_size=5), # to 24x24 @@ -92,8 +92,8 @@ class AutoEncoder(nn.Module): return self.decoder(z.view(z.size(0), -1, 1, 1)) def forward(self, x): - x = self.encoder(x) - x = self.decoder(x) + x = self.encode(x) + x = self.decode(x) return x @@ -124,20 +124,22 @@ test_input.sub_(mu).div_(std) ###################################################################### -for epoch in range(args.nb_epochs): - acc_loss = 0 +for n_epoch in range(args.nb_epochs): + acc_train_loss = 0 for input in train_input.split(args.batch_size): output = model(input) - loss = 0.5 * (output - input).pow(2).sum() / input.size(0) + train_loss = F.mse_loss(output, input) optimizer.zero_grad() - loss.backward() + train_loss.backward() optimizer.step() - acc_loss += loss.item() + acc_train_loss += train_loss.detach().item() * input.size(0) - log_string("acc_loss {:d} {:f}.".format(epoch, acc_loss)) + train_loss = acc_train_loss / train_input.size(0) + + log_string(f"train_loss {n_epoch} {train_loss}") ######################################################################