From 3afcea624963ad2d381c19a7d54bb26e218c5bce Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 13 Jun 2024 20:05:29 +0200 Subject: [PATCH] Update. --- redshift.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/redshift.py b/redshift.py index b3507ed..2ed1e52 100755 --- a/redshift.py +++ b/redshift.py @@ -9,8 +9,10 @@ from torch.nn import functional as F torch.set_default_dtype(torch.float64) +nb_hidden = 5 +hidden_dim = 100 + res = 256 -nh = 100 input = torch.cat( [ @@ -28,11 +30,10 @@ class Angles(nn.Module): for activation in [nn.ReLU, nn.Tanh, nn.Softplus, Angles]: for s in [1.0, 10.0]: - layers = [nn.Linear(2, nh), activation()] - nb_hidden = 4 - for k in range(nb_hidden): - layers += [nn.Linear(nh, nh), activation()] - layers += [nn.Linear(nh, 2)] + layers = [nn.Linear(2, hidden_dim), activation()] + for k in range(nb_hidden - 1): + layers += [nn.Linear(hidden_dim, hidden_dim), activation()] + layers += [nn.Linear(hidden_dim, 2)] model = nn.Sequential(*layers) with torch.no_grad(): -- 2.39.5