Update. master
authorFrançois Fleuret <francois@fleuret.org>
Thu, 13 Jun 2024 18:05:29 +0000 (20:05 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 13 Jun 2024 18:05:29 +0000 (20:05 +0200)
redshift.py

index b3507ed..2ed1e52 100755 (executable)
@@ -9,8 +9,10 @@ from torch.nn import functional as F
 
 torch.set_default_dtype(torch.float64)
 
 
 torch.set_default_dtype(torch.float64)
 
+nb_hidden = 5
+hidden_dim = 100
+
 res = 256
 res = 256
-nh = 100
 
 input = torch.cat(
     [
 
 input = torch.cat(
     [
@@ -28,11 +30,10 @@ class Angles(nn.Module):
 
 for activation in [nn.ReLU, nn.Tanh, nn.Softplus, Angles]:
     for s in [1.0, 10.0]:
 
 for activation in [nn.ReLU, nn.Tanh, nn.Softplus, Angles]:
     for s in [1.0, 10.0]:
-        layers = [nn.Linear(2, nh), activation()]
-        nb_hidden = 4
-        for k in range(nb_hidden):
-            layers += [nn.Linear(nh, nh), activation()]
-        layers += [nn.Linear(nh, 2)]
+        layers = [nn.Linear(2, hidden_dim), activation()]
+        for k in range(nb_hidden - 1):
+            layers += [nn.Linear(hidden_dim, hidden_dim), activation()]
+        layers += [nn.Linear(hidden_dim, 2)]
         model = nn.Sequential(*layers)
 
         with torch.no_grad():
         model = nn.Sequential(*layers)
 
         with torch.no_grad():