Update.
[pytorch.git] / tinymnist.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import time, os
9 import torch, torchvision
10 from torch import nn
11 from torch.nn import functional as F
12
13 lr, nb_epochs, batch_size = 1e-1, 10, 100
14
15 data_dir = os.environ.get("PYTORCH_DATA_DIR") or "./data/"
16
17 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
19 ######################################################################
20
21 train_set = torchvision.datasets.MNIST(root=data_dir, train=True, download=True)
22 train_input = train_set.data.view(-1, 1, 28, 28).float()
23 train_targets = train_set.targets
24
25 test_set = torchvision.datasets.MNIST(root=data_dir, train=False, download=True)
26 test_input = test_set.data.view(-1, 1, 28, 28).float()
27 test_targets = test_set.targets
28
29 ######################################################################
30
31
32 class SomeLeNet(nn.Module):
33     def __init__(self):
34         super().__init__()
35         self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
36         self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
37         self.fc1 = nn.Linear(256, 200)
38         self.fc2 = nn.Linear(200, 10)
39
40     def forward(self, x):
41         x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=3))
42         x = F.relu(F.max_pool2d(self.conv2(x), kernel_size=2))
43         x = x.view(x.size(0), -1)
44         x = F.relu(self.fc1(x))
45         x = self.fc2(x)
46         return x
47
48
49 ######################################################################
50
51 model = SomeLeNet()
52
53 nb_parameters = sum(p.numel() for p in model.parameters())
54
55 print(f"nb_parameters {nb_parameters}")
56
57 optimizer = torch.optim.SGD(model.parameters(), lr=lr)
58 criterion = nn.CrossEntropyLoss()
59
60 model.to(device)
61 criterion.to(device)
62
63 train_input, train_targets = train_input.to(device), train_targets.to(device)
64 test_input, test_targets = test_input.to(device), test_targets.to(device)
65
66 mu, std = train_input.mean(), train_input.std()
67 train_input.sub_(mu).div_(std)
68 test_input.sub_(mu).div_(std)
69
70 start_time = time.perf_counter()
71
72 for k in range(nb_epochs):
73     acc_loss = 0.0
74
75     for input, targets in zip(
76         train_input.split(batch_size), train_targets.split(batch_size)
77     ):
78         output = model(input)
79         loss = criterion(output, targets)
80         acc_loss += loss.item()
81
82         optimizer.zero_grad()
83         loss.backward()
84         optimizer.step()
85
86     nb_test_errors = 0
87     for input, targets in zip(
88         test_input.split(batch_size), test_targets.split(batch_size)
89     ):
90         wta = model(input).argmax(1)
91         nb_test_errors += (wta != targets).long().sum()
92     test_error = nb_test_errors / test_input.size(0)
93     duration = time.perf_counter() - start_time
94
95     print(f"loss {k} {duration:.02f}s {acc_loss:.02f} {test_error*100:.02f}%")
96
97 ######################################################################