-# Take the origins of the arrows on the part of grid closer than 0.1
-# from the data points
-dist = (grid.view(-1, 1, 2) - data.view(1, -1, 2)).pow(2).sum(2).min(1)[0]
-origins = grid[torch.arange(grid.size(0)).masked_select(dist < 0.1)]
+def save_image(data_name, model, data):
+ a = torch.linspace(-1.5, 1.5, 30)
+ x = a.view(1, -1, 1).expand(a.size(0), a.size(0), 1)
+ y = a.view(-1, 1, 1).expand(a.size(0), a.size(0), 1)
+ grid = torch.cat((y, x), 2).view(-1, 2)