Update: retune c-quiz generation defaults, add an experimental auto-encoder mode to MyGPT, warn on the simple c-quiz generation path.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 27 Jul 2024 03:31:22 +0000 (05:31 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 27 Jul 2024 03:31:22 +0000 (05:31 +0200)
main.py
mygpt.py
quiz_machine.py

diff --git a/main.py b/main.py
index ffdd16f..3787e9f 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -95,11 +95,11 @@ parser.add_argument("--proba_understands", type=float, default=0.9)
 
 parser.add_argument("--proba_not_understands", type=float, default=0.5)
 
-parser.add_argument("--temperature_hot", type=float, default=2)
+parser.add_argument("--temperature_hot", type=float, default=1.25)
 
-parser.add_argument("--temperature_cold", type=float, default=0.75)
+parser.add_argument("--temperature_cold", type=float, default=1.25)
 
-parser.add_argument("--nb_rounds", type=int, default=1)
+parser.add_argument("--nb_rounds", type=int, default=2)
 
 parser.add_argument("--c_quiz_validation_mode", type=str, default="predict")
 
@@ -645,6 +645,94 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 
 
 ######################################################################
+# Experimental: auto-encoder training prototype (work in progress).
+
+
+def train_auto_encoder():
+    model = mygpt.MyGPT(
+        vocabulary_size=vocabulary_size,
+        dim_model=args.dim_model,
+        dim_keys=args.dim_keys,
+        dim_hidden=args.dim_hidden,
+        nb_heads=args.nb_heads,
+        nb_blocks=args.nb_blocks,
+        causal=False,
+        dropout=args.dropout,
+        auto_encoder_dim=64,
+    ).to(main_device)
+
+    test_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples)
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+
+    for n_epoch in range(args.nb_epochs):
+        # reset per-epoch statistics so the logged perplexity covers one epoch
+        nb_train_samples, acc_train_loss = 0, 0.0
+        train_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples)
+        for input in tqdm.tqdm(
+            train_w_quizzes.split(args.batch_size),
+            dynamic_ncols=True,
+            desc="training AE",
+            total=train_w_quizzes.size(0) // args.batch_size,
+        ):
+            model.train()
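+            # keep only the last quarter of each quiz (presumably the f_B grid)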
+            l = input.size(1) // 4
+            input = input[:, -l:].to(main_device)
+
+            if nb_train_samples % args.batch_size == 0:
+                optimizer.zero_grad()
+
+            z_shape = model.encode(mygpt.BracketedSequence(input))
+            output = model.decode(z_shape).x
+            loss = F.cross_entropy(output.transpose(1, 2), input)
+            acc_train_loss += loss.item() * input.size(0)
+
+            nb_train_samples += input.size(0)
+
+            loss.backward()
+
+            if nb_train_samples % args.batch_size == 0:
+                optimizer.step()
+
+        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+
+        log_string(f"train_perplexity {n_epoch} model ae {train_perplexity}")
+
+        filename = "auto_encoder.pth"
+        torch.save(
+            model.state_dict(),
+            os.path.join(args.result_dir, filename),
+        )
+        log_string(f"wrote {filename}")
+
+        with torch.autograd.no_grad():
+            model.eval()
+            input = test_w_quizzes[:128, -l:]
+            z_shape = model.encode(mygpt.BracketedSequence(input.to(main_device)))
+            logits = model.decode(z_shape).x
+
+            # dist = torch.distributions.categorical.Categorical(logits=logits)
+            # q = dist.sample()
+
+            q = logits.argmax(dim=-1)
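+            # each saved row: two input grids followed by their two reconstructions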
+            q = q.reshape(q.size(0) // 2, 2, -1)
+            input = input.reshape(input.size(0) // 2, 2, -1)
+            q = torch.cat([input.to("cpu"), q.to("cpu")], dim=1).reshape(q.size(0), -1)
+            quiz_machine.problem.save_quizzes_as_image(
+                args.result_dir,
+                f"culture_ae_{n_epoch:04d}.png",
+                q,
+            )
+
+    return model
+
+
+# ae = train_auto_encoder()
+
+# exit(0)
+
+######################################################################
+
 
 models = []
 
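For reference, a minimal sketch of how the checkpoint written above could be reloaded to map quizzes to their 64-dimensional latent codes. It assumes the same globals as train_auto_encoder (args, mygpt, main_device, vocabulary_size) and quizzes already sliced to the last-quarter view used during training; load_and_encode is an illustrative name, not part of this commit:

def load_and_encode(quizzes):
    # Hypothetical helper: rebuild the auto-encoder with the same
    # hyper-parameters, load the checkpoint, and pool to latent codes.
    model = mygpt.MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=args.dim_model,
        dim_keys=args.dim_keys,
        dim_hidden=args.dim_hidden,
        nb_heads=args.nb_heads,
        nb_blocks=args.nb_blocks,
        causal=False,
        dropout=args.dropout,
        auto_encoder_dim=64,
    ).to(main_device)
    state = torch.load(os.path.join(args.result_dir, "auto_encoder.pth"))
    model.load_state_dict(state)
    model.eval()
    with torch.no_grad():
        z, _ = model.encode(mygpt.BracketedSequence(quizzes.to(main_device)))
    return z  # (batch, 64) latent codes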
diff --git a/mygpt.py b/mygpt.py
index 51c0862..9bec09e 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -114,6 +114,30 @@ class AddPositionalEncoding(nn.Module):
 ##############################
 
 
+class EncoderHead(nn.Module):
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.fc = nn.Linear(dim_in, dim_out)
+
+    def forward(self, bs):
+        z = self.fc(bs.x).mean(dim=1)
+        return z, bs.x.shape
+
+
+class DecoderBottom(nn.Module):
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.fc = nn.Linear(dim_in, dim_out)
+
+    def forward(self, z_shape):
+        z, shape = z_shape
+        y = self.fc(z)[:, None, :].expand(shape)
+        return BracketedSequence(y)
+
+
+##############################
+
+
 class QKVAttention(nn.Module):
     def __init__(
         self,
@@ -232,6 +256,7 @@ class MyGPT(nn.Module):
         nb_blocks,
         causal=False,
         dropout=0.0,
+        auto_encoder_dim=-1,
         len_max=1e5,
     ):
         super().__init__()
@@ -273,6 +298,24 @@ class MyGPT(nn.Module):
                 ),
             ]
 
+        if auto_encoder_dim > 0:
+            self.encoder = nn.Sequential(
+                *(
+                    trunk_blocks[: nb_blocks // 2]
+                    + [EncoderHead(dim_model, auto_encoder_dim)]
+                )
+            )
+
+            self.decoder = nn.Sequential(
+                *(
+                    [
+                        DecoderBottom(auto_encoder_dim, dim_model),
+                        AddPositionalEncoding(len_max),
+                    ]
+                    + trunk_blocks[nb_blocks // 2 :]
+                )
+            )
+
         self.trunk = nn.Sequential(*trunk_blocks)
 
         self.readout = CacheWrapper(
@@ -288,13 +331,22 @@ class MyGPT(nn.Module):
                     m.weight.fill_(1.0)
 
     def forward(self, bs):
-        # print(f"GENERATE {bs.first} {bs.first+bs.nb}")
         bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
         bs = self.embedding(bs)
         bs = self.trunk(bs)
         bs = self.readout(bs)
         return bs
 
+    def encode(self, bs):
+        bs = self.embedding(bs)
+        z = self.encoder(bs)
+        return z
+
+    def decode(self, z_shape):
+        bs = self.decoder(z_shape)
+        bs = self.readout(bs)
+        return bs
+
     def partial_forward(self, bs, start_layer=None, end_layer=None):
         if start_layer is None:
             # print(f"GENERATE {bs.first} {bs.first+bs.nb}")
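The bottleneck added above operates purely on shapes: EncoderHead projects the trunk activations from dim_model down to auto_encoder_dim and mean-pools over the sequence into one vector per sample, and DecoderBottom projects that vector back to dim_model and broadcasts it across the sequence length. A standalone sketch of the round trip, with illustrative sizes:

import torch
import torch.nn as nn

B, T, dim_model, dim_latent = 4, 100, 512, 64  # illustrative sizes

enc_fc = nn.Linear(dim_model, dim_latent)  # plays the role of EncoderHead.fc
dec_fc = nn.Linear(dim_latent, dim_model)  # plays the role of DecoderBottom.fc

x = torch.randn(B, T, dim_model)            # trunk activations
z = enc_fc(x).mean(dim=1)                   # (B, dim_latent): pooled latent code
y = dec_fc(z)[:, None, :].expand(x.shape)   # (B, T, dim_model): broadcast back

assert z.shape == (B, dim_latent)
assert y.shape == x.shape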
diff --git a/quiz_machine.py b/quiz_machine.py
index a9319c7..7516aed 100755 (executable)
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -496,13 +496,17 @@ class QuizMachine:
 
     ######################################################################
 
-    def generate_c_quizzes_simple(
+    def generate_c_quizzes_(
         self,
         nb,
         model_for_generation,
         temperature_hot=1.0,
         temperature_cold=1.0,
     ):
+        warnings.warn(
+            "using the simple c-quiz generation path", RuntimeWarning
+        )
+
         c_quizzes = self.problem.create_empty_quizzes(nb, ("A", "f_A", "B", "f_B"))
         c_quizzes = c_quizzes.to(self.device)
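The temperature_hot / temperature_cold arguments (whose defaults change in main.py above) follow the usual logit-scaling convention; the sampling body is not shown in this diff, so the sketch below is only the standard mechanism, not the exact code:

import torch
import torch.nn.functional as F

def sample_with_temperature(logits, temperature=1.0):
    # logits: (batch, vocabulary_size)
    # temperature > 1 flattens the distribution (more diverse tokens),
    # temperature < 1 sharpens it toward the argmax.
    probs = F.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)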