Update. master
authorFrançois Fleuret <francois@fleuret.org>
Mon, 17 Jun 2024 13:53:17 +0000 (15:53 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 17 Jun 2024 13:53:17 +0000 (15:53 +0200)
main.py
turing.py [new file with mode: 0755]

diff --git a/main.py b/main.py
index 37515b5..3ff64b7 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -844,7 +844,7 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
             input = input.to(device)
 
             bs = model(mygpt.BracketedSequence(input))
-            output_ar = bs.x
+            output = bs.x
 
             loss = F.cross_entropy(output.transpose(1, 2), input)
 
diff --git a/turing.py b/turing.py
new file mode 100755 (executable)
index 0000000..2bcdeeb
--- /dev/null
+++ b/turing.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python
+
+import torch
+
+
+def generate_turing_sequences(N, nb_iter=5, nb_states=3, nb_symbols=4, tape_size=5):
+    next_state = torch.randint(nb_states, (N, nb_states, nb_symbols))
+    next_symbol = torch.randint(nb_symbols, (N, nb_states, nb_symbols))
+    next_move = torch.randint(3, (N, nb_states, nb_symbols))
+
+    all_n = torch.arange(N)
+
+    tape = torch.randint(nb_symbols, (N, tape_size))
+    # position = torch.randint(tape_size, (N,))
+    # state = torch.randint(nb_states, (N,))
+    position = torch.zeros(N, dtype=torch.int64)
+    state = torch.zeros(N, dtype=torch.int64)
+
+    result = []
+
+    for _ in range(nb_iter):
+        result.append(tape.clone())
+        current_symbol = tape[all_n, position]
+        tape[all_n, position] = next_symbol[all_n, state, current_symbol]
+        position = (position + next_move[all_n, state, current_symbol] - 1) % tape_size
+        state = next_state[all_n, state, current_symbol]
+
+    result = torch.cat([x[:, None, :] for x in result], dim=1)
+
+    return result
+
+
+######################################################################
+
+if __name__ == "__main__":
+    print("Basic check.")
+
+    tapes = generate_turing_sequences(1, nb_iter=10)
+
+    for i in range(tapes.size(1)):
+        # print(f"- {i:03d} ------------------------")
+        # for s, h, r in zip(state, position, tape):
+        # print("".join([f"{x}" for x in r]))
+        # print(" " * h + f"^[{s}]")
+        for r in tapes:
+            print("".join([f"{x}" for x in r[i]]))