Update.
author François Fleuret <francois@fleuret.org>
Thu, 11 Jul 2024 15:37:46 +0000 (17:37 +0200)
committer François Fleuret <francois@fleuret.org>
Thu, 11 Jul 2024 15:37:46 +0000 (17:37 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index 73e7ca2..4cf4d59 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -341,7 +341,7 @@ def one_epoch(model, quiz_machine, local_device=None):
 
     train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
 
-    log_string(f"train_perplexity {n_epoch} {train_perplexity}")
+    log_string(f"train_perplexity {n_epoch} model.id {model.id} {train_perplexity}")
 
     run_tests(model, quiz_machine, deterministic_synthesis=False)
 
@@ -354,9 +354,6 @@ def one_epoch(model, quiz_machine, local_device=None):
 def standard_validity(logproba):
     l = logproba.sort(dim=-1).values
     return (l[:, 0] < math.log(0.5)) & (l[:, 1] > math.log(0.99))
-    # warnings.warn("TEST!!!", RuntimeWarning)
-    # print(l.exp())
-    # return (l[:, 0] < math.log(0.99))
 
 
 def valid_c_quizzes(recorded, criteria):
@@ -452,13 +449,9 @@ for k in range(args.nb_gpts):
     model.id = k
     model.TRAINING_LOCK = threading.Lock()
 
-    model.train_w_quizzes = quiz_machine.generate_token_sequences(
-        args.nb_train_samples
-    ).to(device)
+    model.train_w_quizzes = quiz_machine.generate_token_sequences(args.nb_train_samples)
     quiz_machine.reverse_random_half_in_place(model.train_w_quizzes)
-    model.test_w_quizzes = quiz_machine.generate_token_sequences(
-        args.nb_test_samples
-    ).to(device)
+    model.test_w_quizzes = quiz_machine.generate_token_sequences(args.nb_test_samples)
     quiz_machine.reverse_random_half_in_place(model.test_w_quizzes)
 
     models.append(model)
@@ -532,6 +525,11 @@ if args.dirty_debug:
     nb_new_c_quizzes_for_train = 100
     nb_new_c_quizzes_for_test = 10
 
+    def standard_validity(logproba):
+        l = logproba.sort(dim=-1).values
+        return l[:, 0] < math.log(0.5)
+
+
 ######################################################################
 
 for n_epoch in range(args.nb_epochs):
index 1f1046d..ae14614 100755 (executable)
@@ -327,6 +327,7 @@ class QuizMachine:
         self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
     ):
         def compute_accuracy(input, log_prefix=None):
+            input = input.to(self.device)
             ar_mask = self.make_ar_mask(input)
             result = input.clone() * (1 - ar_mask)
             seq_logproba = torch.empty(input.size(0), device=self.device)
@@ -404,26 +405,29 @@ class QuizMachine:
         input[:-nb] = input[nb:].clone()
         fresh_w_quizzes = self.generate_token_sequences(nb)
         self.reverse_random_half_in_place(fresh_w_quizzes)
-        input[-nb:] = fresh_w_quizzes.to(self.device)
+        input[-nb:] = fresh_w_quizzes.to("cpu")
 
     ######################################################################
 
     def store_c_quizzes(self, new_c_quizzes, for_train=True):
         with self.LOCK_C_QUIZZES:
             if for_train:
-                self.train_c_quizzes.append(new_c_quizzes)
+                self.train_c_quizzes.append(new_c_quizzes.to("cpu"))
             else:
-                self.test_c_quizzes.append(new_c_quizzes)
+                self.test_c_quizzes.append(new_c_quizzes.to("cpu"))
 
     ######################################################################
 
     def logproba_of_solutions(self, models, c_quizzes):
-        logproba = c_quizzes.new_zeros(c_quizzes.size(0), len(models))
+        logproba = c_quizzes.new_zeros(
+            c_quizzes.size(0), len(models), device=self.device
+        )
 
         for model in models:
             for input, l in zip(
                 c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
             ):
+                input = input.to(self.device)
                 ar_mask = self.make_ar_mask(input)
                 output = model(mygpt.BracketedSequence(input)).x
                 ce = (
@@ -432,7 +436,7 @@ class QuizMachine:
                 )
                 l[:, model.id] = -ce.sum(dim=-1)
 
-        return logproba
+        return logproba.to("cpu")
 
     ###############################################################
 
@@ -561,4 +565,4 @@ class QuizMachine:
             device=self.device,
         )
 
-        return c_quizzes
+        return c_quizzes.to("cpu")