From: François Fleuret Date: Thu, 13 Jun 2024 18:05:29 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=refs%2Fheads%2Fmaster;p=pytorch.git Update. --- diff --git a/redshift.py b/redshift.py index b3507ed..2ed1e52 100755 --- a/redshift.py +++ b/redshift.py @@ -9,8 +9,10 @@ from torch.nn import functional as F torch.set_default_dtype(torch.float64) +nb_hidden = 5 +hidden_dim = 100 + res = 256 -nh = 100 input = torch.cat( [ @@ -28,11 +30,10 @@ class Angles(nn.Module): for activation in [nn.ReLU, nn.Tanh, nn.Softplus, Angles]: for s in [1.0, 10.0]: - layers = [nn.Linear(2, nh), activation()] - nb_hidden = 4 - for k in range(nb_hidden): - layers += [nn.Linear(nh, nh), activation()] - layers += [nn.Linear(nh, 2)] + layers = [nn.Linear(2, hidden_dim), activation()] + for k in range(nb_hidden - 1): + layers += [nn.Linear(hidden_dim, hidden_dim), activation()] + layers += [nn.Linear(hidden_dim, 2)] model = nn.Sequential(*layers) with torch.no_grad():