Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 13 Aug 2024 21:06:40 +0000 (23:06 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 13 Aug 2024 21:06:40 +0000 (23:06 +0200)
main.py

diff --git a/main.py b/main.py
index 8e06bb2..0b9a86e 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -446,7 +446,7 @@ c_quizzes_procedure = [
     (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_modifier_hot),
     (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_modifier_cold),
     (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold),
-    (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_modifier_cold),
+    (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_modifier_cold),
 ]
 
 ######################################################################
@@ -838,7 +838,15 @@ if args.test == "func":
 
             output = model(mygpt.BracketedSequence(input[:, : 3 * L])).x
             dist = torch.distributions.categorical.Categorical(logits=output)
-            input[:, 3 * L :] = dist.sample()
+            input[:, 3 * L + 1 :] = dist.sample()[:, 1:]
+
+            problem.save_quizzes_as_image(
+                args.result_dir,
+                f"thinker_prediction_{n_epoch:04d}.png",
+                quizzes=input,
+                # predicted_parts=predicted_parts,
+                # correct_parts=correct_parts,
+            )
 
 
 ######################################################################