Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 4 Aug 2024 16:47:44 +0000 (18:47 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 4 Aug 2024 16:47:44 +0000 (18:47 +0200)
quiz_machine.py

index 386969a..332cd86 100755 (executable)
@@ -94,14 +94,14 @@ class QuizMachine:
 
     ######################################################################
 
-    def sigma_for_grids(self, input):
+    def sigma_for_grids(self, input, block_order=(0, 1, 2, 3)):
         l = input.size(1) // 4
         sigma = input.new(input.size())
         r = sigma.view(sigma.size(0), 4, l)
-        r[:, 0] = 0 * l
-        r[:, 1] = 1 * l
-        r[:, 2] = 2 * l
-        r[:, 3] = 3 * l
+        r[:, 0, :] = block_order[0] * l
+        r[:, 1, :] = block_order[1] * l
+        r[:, 2, :] = block_order[2] * l
+        r[:, 3, :] = block_order[3] * l
         r[:, :, 1:] += (
             torch.rand(input.size(0), 4, l - 1, device=input.device).sort(dim=2).indices
         ) + 1