Update.
[pytorch.git] / confidence.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 ######################################################################
16
17 nb = 100
18 delta = 0.35
19 x = torch.empty(nb).uniform_(0.0, delta)
20 x += x.new_full(x.size(), 0.5).bernoulli() * (1 - delta)
21
22 a = x * math.pi * 2 * 4
23 b = x * math.pi * 2 * 3
24 y = a.sin() + b
25
26 x = x.view(-1, 1)
27 y = y.view(-1, 1)
28
29 ######################################################################
30
31 nh = 400
32
33 model = nn.Sequential(
34     nn.Linear(1, nh),
35     nn.ReLU(),
36     nn.Dropout(0.25),
37     nn.Linear(nh, nh),
38     nn.ReLU(),
39     nn.Dropout(0.25),
40     nn.Linear(nh, 1),
41 )
42
43 model.train(True)
44 criterion = nn.MSELoss()
45 optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
46
47 for k in range(10000):
48     loss = criterion(model(x), y)
49     if (k + 1) % 100 == 0:
50         print(k + 1, loss.item())
51     optimizer.zero_grad()
52     loss.backward()
53     optimizer.step()
54
55 ######################################################################
56
57 import matplotlib.pyplot as plt
58
59 fig, ax = plt.subplots()
60
61 u = torch.linspace(0, 1, 101)
62 v = u.view(-1, 1).expand(-1, 25).reshape(-1, 1)
63 v = model(v).reshape(101, -1)
64 mean = v.mean(1)
65 std = v.std(1)
66
67 ax.fill_between(
68     u.numpy(),
69     (mean - std).detach().numpy(),
70     (mean + std).detach().numpy(),
71     color="#e0e0e0",
72 )
73 ax.plot(u.numpy(), mean.detach().numpy(), color="red")
74 ax.scatter(x.numpy(), y.numpy())
75
76 plt.show()
77
78 ######################################################################