Update.
[pytorch.git] / confidence.py
diff --git a/confidence.py b/confidence.py
new file mode 100755 (executable)
index 0000000..ff4b395
--- /dev/null
@@ -0,0 +1,53 @@
+#!/usr/bin/env python
+
+import math
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+######################################################################
+
+nb = 100
+delta = 0.35
+x = torch.empty(nb).uniform_(0.0, delta)
+x += x.new_full(x.size(), 0.5).bernoulli() * (1 - delta)
+
+a = x * math.pi * 2 * 4
+b = x * math.pi * 2 * 3
+y = a.sin() + b
+
+x = x.view(-1, 1)
+y = y.view(-1, 1)
+
+######################################################################
+
+nh = 100
+
+model = nn.Sequential(nn.Linear(1, nh), nn.ReLU(),
+                      nn.Linear(nh, nh), nn.ReLU(),
+                      nn.Linear(nh, 1))
+
+criterion = nn.MSELoss()
+optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
+
+for k in range(10000):
+    loss = criterion(model(x), y)
+    if (k+1)%100 == 0: print(k+1, loss.item())
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
+
+######################################################################
+
+import matplotlib.pyplot as plt
+
+fig, ax = plt.subplots()
+ax.scatter(x.numpy(), y.numpy())
+
+u = torch.linspace(0, 1, 100).view(-1, 1)
+ax.plot(u.numpy(), model(u).detach().numpy(), color = 'red')
+plt.show()
+
+######################################################################