Update
[pytorch] / confidence.py
1 #!/usr/bin/env python
2
3 import math
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9
10 ######################################################################
11
12 nb = 100
13 delta = 0.35
14 x = torch.empty(nb).uniform_(0.0, delta)
15 x += x.new_full(x.size(), 0.5).bernoulli() * (1 - delta)
16
17 a = x * math.pi * 2 * 4
18 b = x * math.pi * 2 * 3
19 y = a.sin() + b
20
21 x = x.view(-1, 1)
22 y = y.view(-1, 1)
23
24 ######################################################################
25
26 nh = 400
27
28 model = nn.Sequential(nn.Linear(1, nh), nn.ReLU(),
29                       nn.Dropout(0.25),
30                       nn.Linear(nh, nh), nn.ReLU(),
31                       nn.Dropout(0.25),
32                       nn.Linear(nh, 1))
33
34 model.train(True)
35 criterion = nn.MSELoss()
36 optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
37
38 for k in range(10000):
39     loss = criterion(model(x), y)
40     if (k+1)%100 == 0: print(k+1, loss.item())
41     optimizer.zero_grad()
42     loss.backward()
43     optimizer.step()
44
45 ######################################################################
46
47 import matplotlib.pyplot as plt
48
49 fig, ax = plt.subplots()
50
51 u = torch.linspace(0, 1, 101)
52 v = u.view(-1, 1).expand(-1, 25).reshape(-1, 1)
53 v = model(v).reshape(101, -1)
54 mean = v.mean(1)
55 std = v.std(1)
56
57 ax.fill_between(u.numpy(), (mean-std).detach().numpy(), (mean+std).detach().numpy(), color = '#e0e0e0')
58 ax.plot(u.numpy(), mean.detach().numpy(), color = 'red')
59 ax.scatter(x.numpy(), y.numpy())
60
61 plt.show()
62
63 ######################################################################