# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>
import os
import time

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
# Training hyper-parameters.
lr = 1e-1
nb_epochs = 10
batch_size = 100

# Dataset location: $PYTORCH_DATA_DIR when set and non-empty, else ./data/.
data_dir = os.environ.get("PYTORCH_DATA_DIR") or "./data/"

# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
######################################################################
# MNIST, loaded entirely into memory: images become (N, 1, 28, 28)
# float tensors, targets are int64 class indices in [0, 9].
train_set = torchvision.datasets.MNIST(root=data_dir, train=True, download=True)
train_targets = train_set.targets
train_input = train_set.data.float().view(-1, 1, 28, 28)

test_set = torchvision.datasets.MNIST(root=data_dir, train=False, download=True)
test_targets = test_set.targets
test_input = test_set.data.float().view(-1, 1, 28, 28)
######################################################################
class SomeLeNet(nn.Module):
    """Small LeNet-style convnet for 28x28 MNIST images.

    Takes a (N, 1, 28, 28) float batch and returns (N, 10) raw logits
    (no softmax -- pair with ``nn.CrossEntropyLoss``).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        # 256 = 64 channels * 2 * 2 spatial cells left after the
        # conv/pool stack in forward() (see shape comments there).
        self.fc1 = nn.Linear(256, 200)
        self.fc2 = nn.Linear(200, 10)

    def forward(self, x):
        # (N,1,28,28) -conv5-> (N,32,24,24) -maxpool3-> (N,32,8,8)
        x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=3))
        # (N,32,8,8) -conv5-> (N,64,4,4) -maxpool2-> (N,64,2,2)
        x = F.relu(F.max_pool2d(self.conv2(x), kernel_size=2))
        x = x.view(x.size(0), -1)  # flatten to (N, 256)
        x = F.relu(self.fc1(x))
        return self.fc2(x)  # raw logits
######################################################################

# Build the model on the training device.
model = SomeLeNet().to(device)

nb_parameters = sum(p.numel() for p in model.parameters())
print(f"nb_parameters {nb_parameters}")

# Plain SGD with the cross-entropy loss over the 10 digit classes.
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

# MNIST fits in memory, so move the full dataset to the device once.
train_input, train_targets = train_input.to(device), train_targets.to(device)
test_input, test_targets = test_input.to(device), test_targets.to(device)

# Standardize with the training-set statistics; the same mu/std are
# applied to the test set (in place) so both splits match.
mu, std = train_input.mean(), train_input.std()
train_input.sub_(mu).div_(std)
test_input.sub_(mu).div_(std)
start_time = time.perf_counter()

for k in range(nb_epochs):
    # --- One SGD pass over the training set ------------------------
    acc_loss = 0.0
    for input, targets in zip(
        train_input.split(batch_size), train_targets.split(batch_size)
    ):
        output = model(input)
        loss = criterion(output, targets)
        acc_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # --- Top-1 error on the test set (no gradients needed) ---------
    nb_test_errors = 0
    with torch.no_grad():
        for input, targets in zip(
            test_input.split(batch_size), test_targets.split(batch_size)
        ):
            wta = model(input).argmax(1)  # winner-take-all prediction
            # .item() keeps the accumulator a plain Python int.
            nb_test_errors += (wta != targets).long().sum().item()
    test_error = nb_test_errors / test_input.size(0)

    duration = time.perf_counter() - start_time
    print(f"loss {k} {duration:.02f}s {acc_loss:.02f} {test_error*100:.02f}%")
######################################################################