5 import torch, torchvision
8 from torch.nn import functional as F
10 torch.set_default_dtype(torch.float64)
17 torch.linspace(-1, 1, res)[None, :, None].expand(res, res, 1),
18 torch.linspace(-1, 1, res)[:, None, None].expand(res, res, 1),
24 class Angles(nn.Module):
26 return x.clamp(min=-0.5, max=0.5)
29 for activation in [nn.ReLU, nn.Tanh, nn.Softplus, Angles]:
31 layers = [nn.Linear(2, nh), activation()]
33 for k in range(nb_hidden):
34 layers += [nn.Linear(nh, nh), activation()]
35 layers += [nn.Linear(nh, 2)]
36 model = nn.Sequential(*layers)
39 for p in model.parameters():
44 img = (output[:, 1] - output[:, 0]).reshape(1, 1, res, res)
46 img = (img - img.mean()) / (1 * img.std())
48 img = img.clamp(min=-1, max=1)
52 (1 + img).clamp(max=1),
53 (1 - img.abs()).clamp(min=0),
54 (1 - img).clamp(max=1),
62 nn.Softplus: "softplus",
66 torchvision.utils.save_image(img, f"result-{name_activation}-{s}.png")