Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 29 Jul 2024 19:14:40 +0000 (21:14 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 29 Jul 2024 19:14:40 +0000 (21:14 +0200)
main.py
mygpt.py

diff --git a/main.py b/main.py
index 9c8e0bd..1cf31b3 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -324,9 +324,10 @@ elif args.problem == "grids":
             nb_threads=args.nb_threads,
             tasks=args.grids_science_tasks,
         )
-        science_w_quizzes = science_problem.generate_w_quizzes(args.nb_test_samples)
+        science_w_quizzes = science_problem.generate_w_quizzes(100)
+
         if not args.resume:
-            problem.save_some_examples(args.result_dir, "science_")
+            science_problem.save_some_examples(args.result_dir, "science_")
 
 
 else:
@@ -454,7 +455,7 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
 def model_transformer_hot(model):
     # model.temperature = args.temperature_hot
-    model.set_noise_injection(5.0, ("ffw", args.nb_blocks // 2))
+    model.set_noise_injection(1.0, ("ffw", args.nb_blocks // 2))
 
 
 def model_transformer_cold(model):
@@ -813,39 +814,34 @@ for k in range(args.nb_gpts):
 current_epoch = 0
 
 if args.resume:
-    try:
-        for model in models:
-            filename = f"gpt_{model.id:03d}.pth"
-
-            try:
-                d = torch.load(os.path.join(args.result_dir, filename))
-                model.load_state_dict(d[0])
-                model.main_test_accuracy = d[1]
-                log_string(f"successfully loaded {filename}")
-            except FileNotFoundError:
-                log_string(f"cannot find {filename}")
-                pass
+    for model in models:
+        filename = f"gpt_{model.id:03d}.pth"
 
         try:
-            filename = "c_quizzes.pth"
-            quiz_machine.load_c_quizzes(os.path.join(args.result_dir, filename))
+            d = torch.load(os.path.join(args.result_dir, filename))
+            model.load_state_dict(d[0])
+            model.main_test_accuracy = d[1]
             log_string(f"successfully loaded {filename}")
         except FileNotFoundError:
             log_string(f"cannot find {filename}")
             pass
 
-        try:
-            filename = "state.pth"
-            state = torch.load(os.path.join(args.result_dir, filename))
-            log_string(f"successfully loaded {filename}")
-            current_epoch = state["current_epoch"]
-        except FileNotFoundError:
-            log_string(f"cannot find {filename}")
-            pass
+    try:
+        filename = "c_quizzes.pth"
+        quiz_machine.load_c_quizzes(os.path.join(args.result_dir, filename))
+        log_string(f"successfully loaded {filename}")
+    except FileNotFoundError:
+        log_string(f"cannot find {filename}")
+        pass
 
-    except:
-        log_string(f"error when loading {filename}.")
-        exit(1)
+    try:
+        filename = "state.pth"
+        state = torch.load(os.path.join(args.result_dir, filename))
+        log_string(f"successfully loaded {filename}")
+        current_epoch = state["current_epoch"]
+    except FileNotFoundError:
+        log_string(f"cannot find {filename}")
+        pass
 
 ######################################################################
 
@@ -872,11 +868,6 @@ if args.dirty_debug:
     args.nb_new_c_quizzes_for_train = 100
     args.nb_new_c_quizzes_for_test = 10
 
-    def compute_valid_quizzes(token_logprobas):
-        l = token_logprobas.sum(dim=-1).sort(dim=-1).values
-        return torch.rand(l[:, 0].size(), device=l.device) < 0.5
-
-
 ######################################################################
 
 for n_epoch in range(current_epoch, args.nb_epochs):
index 7c51bae..15ed80e 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -233,7 +233,9 @@ class NoiseInjector(nn.Module):
 
     def forward(self, x):
         if self.noise_std > 0:
-            x = x + torch.randn(x.size(), device=x.device) * self.noise_std
+            x = x * (
+                1 - 2 * (torch.rand(x.size(), device=x.device) < self.noise_std).long()
+            )
         return x