X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=confidence.py;fp=confidence.py;h=ff4b395e8b0b96df1ebf0759cc580f802500fb81;hb=4469498b31c1fb90cb2b1202dbaf86be0f2d18b0;hp=0000000000000000000000000000000000000000;hpb=99fab4ddc7ee5fedf7a898a9263e2c271ea7d721;p=pytorch.git diff --git a/confidence.py b/confidence.py new file mode 100755 index 0000000..ff4b395 --- /dev/null +++ b/confidence.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python + +import math + +import torch, torchvision + +from torch import nn +from torch.nn import functional as F + +###################################################################### + +nb = 100 +delta = 0.35 +x = torch.empty(nb).uniform_(0.0, delta) +x += x.new_full(x.size(), 0.5).bernoulli() * (1 - delta) + +a = x * math.pi * 2 * 4 +b = x * math.pi * 2 * 3 +y = a.sin() + b + +x = x.view(-1, 1) +y = y.view(-1, 1) + +###################################################################### + +nh = 100 + +model = nn.Sequential(nn.Linear(1, nh), nn.ReLU(), + nn.Linear(nh, nh), nn.ReLU(), + nn.Linear(nh, 1)) + +criterion = nn.MSELoss() +optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3) + +for k in range(10000): + loss = criterion(model(x), y) + if (k+1)%100 == 0: print(k+1, loss.item()) + optimizer.zero_grad() + loss.backward() + optimizer.step() + +###################################################################### + +import matplotlib.pyplot as plt + +fig, ax = plt.subplots() +ax.scatter(x.numpy(), y.numpy()) + +u = torch.linspace(0, 1, 100).view(-1, 1) +ax.plot(u.numpy(), model(u).detach().numpy(), color = 'red') +plt.show() + +######################################################################