-a = torch.linspace(-1.5, 1.5, 30)
-x = a.view( 1, -1, 1).expand(a.size(0), a.size(0), 1)
-y = a.view(-1, 1, 1).expand(a.size(0), a.size(0), 1)
-grid = torch.cat((y, x), 2).view(-1, 2)
+def train_model(data):
+ model = nn.Sequential(
+ nn.Linear(2, 100),
+ nn.ReLU(),
+ nn.Linear(100, 2)
+ )
+
+ batch_size, nb_epochs = 100, 1000
+ optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
+ criterion = nn.MSELoss()
+
+ for e in range(nb_epochs):
+ acc_loss = 0
+ for input in data.split(batch_size):
+ noise = input.new(input.size()).normal_(0, 0.1)
+ output = model(input + noise)
+ loss = criterion(output, input)
+ acc_loss += loss.item()
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ if (e+1)%100 == 0: print(e+1, acc_loss)
+
+ return model