5 import torch, torchvision
8 from torch.nn import functional as F
10 ######################################################################
14 x = torch.empty(nb).uniform_(0.0, delta)
15 x += x.new_full(x.size(), 0.5).bernoulli() * (1 - delta)
17 a = x * math.pi * 2 * 4
18 b = x * math.pi * 2 * 3
24 ######################################################################
28 model = nn.Sequential(nn.Linear(1, nh), nn.ReLU(),
29 nn.Linear(nh, nh), nn.ReLU(),
32 criterion = nn.MSELoss()
33 optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
35 for k in range(10000):
36 loss = criterion(model(x), y)
37 if (k+1)%100 == 0: print(k+1, loss.item())
42 ######################################################################
44 import matplotlib.pyplot as plt
46 fig, ax = plt.subplots()
47 ax.scatter(x.numpy(), y.numpy())
49 u = torch.linspace(0, 1, 100).view(-1, 1)
50 ax.plot(u.numpy(), model(u).detach().numpy(), color = 'red')
53 ######################################################################