From: François Fleuret Date: Tue, 26 Mar 2024 17:45:58 +0000 (+0100) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=7195d3207fccf4ea38238bdde50399ea344a695f;p=pytorch.git Update. --- diff --git a/bit_mlp.py b/bit_mlp.py index 85262b7..8fffe7a 100755 --- a/bit_mlp.py +++ b/bit_mlp.py @@ -59,6 +59,8 @@ class QLinear(nn.Module): for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]: for linear_layer in [nn.Linear, QLinear]: + # The model + model = nn.Sequential( nn.Flatten(), linear_layer(784, nb_hidden), @@ -72,10 +74,9 @@ for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]: optimizer = torch.optim.Adam(model.parameters(), lr=lr) - ###################################################################### + # for k in range(nb_epochs): - ############################################ # Train model.train() @@ -93,7 +94,6 @@ for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]: loss.backward() optimizer.step() - ############################################ # Test model.eval()