Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 1 Jul 2024 06:40:59 +0000 (09:40 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 1 Jul 2024 06:40:59 +0000 (09:40 +0300)
main.py
mygpt.py
quizz_machine.py
wireworld.py

diff --git a/main.py b/main.py
index 6e5545c..11eb8fd 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -434,9 +434,9 @@ def create_c_quizzes(
         for n in range(nb_correct.max() + 1):
             recorded[n].append(new_c_quizzes[nb_correct == n].clone())
 
-        log_string(
-            f"keep c_quizzes {nb_validated()*100/nb_generated():.02f}% kept total {nb_validated()} / {nb_to_create}"
-        )
+        nv = [recorded[n][-1].size(0) for n in recorded.keys()]
+
+        log_string(f"keep c_quizzes kept {nv} total {nb_validated()} / {nb_to_create}")
 
     # concatenate and shuffle
     for n in recorded.keys():
index 7119c7a..d0fda7e 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -201,6 +201,26 @@ class QKVAttention(nn.Module):
 ##############################
 
 
+class NoiseInjector(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.noise_std = 0.0
+
+    def forward(self, x):
+        if self.noise_std > 0:
+            x = x + torch.randn(x.size(), device=x.device) * self.noise_std
+        return x
+
+
+def set_noise_injection(model, noise_std):
+    for m in model.modules():
+        if isinstance(m, NoiseInjector):
+            m.noise_std = noise_std
+
+
+##############################
+
+
 class MyGPT(nn.Module):
     def __init__(
         self,
@@ -228,7 +248,10 @@ class MyGPT(nn.Module):
         for b in range(nb_blocks):
             trunk_blocks += [
                 WithResidual(
-                    CacheWrapper(nn.LayerNorm((dim_model,))),
+                    CacheWrapper(
+                        nn.LayerNorm((dim_model,)),
+                        NoiseInjector(),
+                    ),
                     QKVAttention(
                         dim_in=dim_model,
                         dim_qk=dim_keys,
@@ -241,6 +264,7 @@ class MyGPT(nn.Module):
                 WithResidual(
                     CacheWrapper(
                         nn.LayerNorm((dim_model,)),
+                        NoiseInjector(),
                         nn.Linear(in_features=dim_model, out_features=dim_hidden),
                         nn.ReLU(),
                         nn.Linear(in_features=dim_hidden, out_features=dim_model),
index 6cad6a1..84bb558 100755 (executable)
@@ -12,6 +12,7 @@ import torch, torchvision
 from torch import nn
 from torch.nn import functional as F
 
+import mygpt
 from mygpt import BracketedSequence
 
 ######################################################################
@@ -20,7 +21,7 @@ from mygpt import BracketedSequence
 class Gang(nn.Module):
     def __init__(self, models, nb_models_for_generation, mode="groupthink"):
         super().__init__()
-        self.models = models
+        self.models = nn.ModuleList(models)
         self.nb_models_for_generation = nb_models_for_generation
         self.mode = mode
 
@@ -383,58 +384,39 @@ class QuizzMachine:
         ar_mask_solve = 1 - ar_mask_prompt
         seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device)
 
-        # bracketing of the temperature to get the target logproba if
-        # min_ave_seq_logproba is not None
+        warnings.warn("noise injection", RuntimeWarning)
+        temperature = 1
+        noise_std = torch.rand(1).item()
+        self.logger(f"{noise_std=}")
+        mygpt.set_noise_injection(model_for_generation, noise_std)
 
-        temperature = 2
-        d_temperature = 1 / 3
-
-        while True:
-            seq_logproba[...] = 0
-
-            masked_inplace_autoregression(
-                model=model_for_generation,
-                batch_size=self.batch_size,
-                input=c_quizzes,
-                ar_mask=ar_mask_prompt,
-                seq_logproba=seq_logproba,
-                temperature=temperature,
-                deterministic_synthesis=False,
-                # progress_bar_desc="sampling c_quizzes",
-                device=self.device,
-            )
+        masked_inplace_autoregression(
+            model=model_for_generation,
+            batch_size=self.batch_size,
+            input=c_quizzes,
+            ar_mask=ar_mask_prompt,
+            seq_logproba=seq_logproba,
+            temperature=temperature,
+            deterministic_synthesis=False,
+            # progress_bar_desc="sampling c_quizzes",
+            device=self.device,
+        )
 
-            ave_seq_logproba = seq_logproba.mean()
+        ave_seq_logproba = seq_logproba.mean()
 
-            masked_inplace_autoregression(
-                model=model_for_generation,
-                batch_size=self.batch_size,
-                input=c_quizzes,
-                ar_mask=ar_mask_solve,
-                seq_logproba=seq_logproba,
-                temperature=temperature,
-                deterministic_synthesis=True,
-                # progress_bar_desc="sampling c_quizzes",
-                device=self.device,
-            )
+        masked_inplace_autoregression(
+            model=model_for_generation,
+            batch_size=self.batch_size,
+            input=c_quizzes,
+            ar_mask=ar_mask_solve,
+            seq_logproba=seq_logproba,
+            temperature=temperature,
+            deterministic_synthesis=True,
+            # progress_bar_desc="sampling c_quizzes",
+            device=self.device,
+        )
 
-            # If we do not have target logprobs, get out now
-            if min_ave_seq_logproba is None:
-                break
-
-            # Oh man that's ugly
-            if ave_seq_logproba < min_ave_seq_logproba:
-                if d_temperature > 0:
-                    d_temperature *= -1 / 3
-                temperature += d_temperature
-            elif ave_seq_logproba > min_ave_seq_logproba * 0.99:
-                if d_temperature < 0:
-                    d_temperature *= -1 / 3
-                temperature += d_temperature
-            else:
-                break
-
-            self.logger(f"changing temperature to {temperature}")
+        mygpt.set_noise_injection(model_for_generation, 0.0)
 
         return c_quizzes, seq_logproba.mean()
 
index 65b12ad..8257cad 100755 (executable)
@@ -62,9 +62,10 @@ class Wireworld(problem.Problem):
 
     def generate_frame_sequences_hard(self, nb):
         frame_sequences = []
+        nb_frames = (self.nb_iterations - 1) * self.speed + 1
 
         result = torch.full(
-            (nb * 4, self.nb_iterations * self.speed, self.height, self.width),
+            (nb * 4, nb_frames, self.height, self.width),
             self.token_empty,
         )
 
@@ -116,8 +117,8 @@ class Wireworld(problem.Problem):
                         result[n, 0, i + vi, j + vj] = self.token_tail
                         break
 
-                if torch.rand(1) < 0.75:
-                    break
+                if torch.rand(1) < 0.75:
+                break
 
         weight = torch.full((1, 1, 3, 3), 1.0)
 
@@ -130,7 +131,10 @@ class Wireworld(problem.Problem):
         # tail->conductor
         # conductor->head if 1 or 2 head in the neighborhood, or remains conductor
 
-        for l in range(self.nb_iterations * self.speed - 1):
+        nb_heads = (result[:, 0] == self.token_head).flatten(1).long().sum(dim=1)
+        valid = nb_heads > 0
+
+        for l in range(nb_frames - 1):
             nb_head_neighbors = (
                 F.conv2d(
                     input=(result[:, l] == self.token_head).float()[:, None, :, :],
@@ -153,6 +157,13 @@ class Wireworld(problem.Problem):
                     + (1 - mask_1_or_2_heads) * self.token_conductor
                 )
             )
+            pred_nb_heads = nb_heads
+            nb_heads = (
+                (result[:, l + 1] == self.token_head).flatten(1).long().sum(dim=1)
+            )
+            valid = torch.logical_and(valid, (nb_heads >= pred_nb_heads))
+
+        result = result[valid]
 
         result = result[
             :, torch.arange(self.nb_iterations, device=result.device) * self.speed