Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 13:26:39 +0000 (15:26 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 13:26:39 +0000 (15:26 +0200)
main.py

diff --git a/main.py b/main.py
index edc366a..6b137bf 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -658,7 +658,7 @@ for i in range(args.nb_models):
         nb_heads=args.nb_heads,
         nb_blocks=args.nb_blocks,
         dropout=args.dropout,
-    ).to(main_device)
+    )
 
     # model = torch.compile(model)
 
@@ -666,9 +666,6 @@ for i in range(args.nb_models):
     model.test_accuracy = 0.0
     model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
 
-    model.to(main_device).train()
-    optimizer_to(model.optimizer, main_device)
-
     models.append(model)
 
 ######################################################################