From 7195d3207fccf4ea38238bdde50399ea344a695f Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 26 Mar 2024 18:45:58 +0100 Subject: [PATCH] Update. --- bit_mlp.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bit_mlp.py b/bit_mlp.py index 85262b7..8fffe7a 100755 --- a/bit_mlp.py +++ b/bit_mlp.py @@ -59,6 +59,8 @@ class QLinear(nn.Module): for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]: for linear_layer in [nn.Linear, QLinear]: + # The model + model = nn.Sequential( nn.Flatten(), linear_layer(784, nb_hidden), @@ -72,10 +74,9 @@ for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]: optimizer = torch.optim.Adam(model.parameters(), lr=lr) - ###################################################################### + # for k in range(nb_epochs): - ############################################ # Train model.train() @@ -93,7 +94,6 @@ for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]: loss.backward() optimizer.step() - ############################################ # Test model.eval() -- 2.39.5