From 22415499c0a91922e51f9e2cade009fd404351dc Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 14 Jun 2024 11:34:41 +0200 Subject: [PATCH 1/3] Update. --- main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.py b/main.py index 37515b5..3ff64b7 100755 --- a/main.py +++ b/main.py @@ -844,7 +844,7 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs): input = input.to(device) bs = model(mygpt.BracketedSequence(input)) - output_ar = bs.x + output = bs.x loss = F.cross_entropy(output.transpose(1, 2), input) -- 2.20.1 From b6228999b93968b7362b70b1b570e622a954b805 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 17 Jun 2024 15:41:51 +0200 Subject: [PATCH 2/3] Update. --- turing.py | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100755 turing.py diff --git a/turing.py b/turing.py new file mode 100755 index 0000000..66c7f03 --- /dev/null +++ b/turing.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python + +import torch + + +def generate_turing_sequences(N, nb_iter=5, nb_states=4, nb_symbols=2, tape_size=5): + next_state = torch.randint(nb_states, (N, nb_states, nb_symbols)) + next_symbol = torch.randint(nb_symbols, (N, nb_states, nb_symbols)) + next_move = torch.randint(3, (N, nb_states, nb_symbols)) + + all_n = torch.arange(N) + + tape = torch.randint(nb_symbols, (N, tape_size)) + position = torch.randint(tape_size, (N,)) + state = torch.randint(nb_states, (N,)) + + result = [] + + for _ in range(nb_iter): + result.append(tape) + current_symbol = tape[all_n, position] + tape[all_n, position] = next_symbol[all_n, state, current_symbol] + position = (position + next_move[all_n, state, current_symbol] - 1) % tape_size + state = next_state[all_n, state, current_symbol] + + result = torch.cat([x[:, None, :] for x in result], dim=1) + + return result + + +###################################################################### + +if __name__ == "__main__": + print("Basic check.") + + tapes = generate_turing_sequences(5) + + for i in range(tapes.size(1)): + print(f"- {i:03d} ------------------------") + # for s, h, r in zip(state, position, tape): + # print("".join([f"{x}" for x in r])) + # print(" " * h + f"^[{s}]") + for r in tapes: + print("".join([f"{x}" for x in r[i]])) -- 2.20.1 From cf94b49d085ec05e1053b49b7e796afa3f593a28 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 17 Jun 2024 15:53:17 +0200 Subject: [PATCH 3/3] Update. --- turing.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/turing.py b/turing.py index 66c7f03..2bcdeeb 100755 --- a/turing.py +++ b/turing.py @@ -3,7 +3,7 @@ import torch -def generate_turing_sequences(N, nb_iter=5, nb_states=4, nb_symbols=2, tape_size=5): +def generate_turing_sequences(N, nb_iter=5, nb_states=3, nb_symbols=4, tape_size=5): next_state = torch.randint(nb_states, (N, nb_states, nb_symbols)) next_symbol = torch.randint(nb_symbols, (N, nb_states, nb_symbols)) next_move = torch.randint(3, (N, nb_states, nb_symbols)) @@ -11,13 +11,15 @@ def generate_turing_sequences(N, nb_iter=5, nb_states=4, nb_symbols=2, tape_size all_n = torch.arange(N) tape = torch.randint(nb_symbols, (N, tape_size)) - position = torch.randint(tape_size, (N,)) - state = torch.randint(nb_states, (N,)) + # position = torch.randint(tape_size, (N,)) + # state = torch.randint(nb_states, (N,)) + position = torch.zeros(N, dtype=torch.int64) + state = torch.zeros(N, dtype=torch.int64) result = [] for _ in range(nb_iter): - result.append(tape) + result.append(tape.clone()) current_symbol = tape[all_n, position] tape[all_n, position] = next_symbol[all_n, state, current_symbol] position = (position + next_move[all_n, state, current_symbol] - 1) % tape_size @@ -33,10 +35,10 @@ def generate_turing_sequences(N, nb_iter=5, nb_states=4, nb_symbols=2, tape_size if __name__ == "__main__": print("Basic check.") - tapes = generate_turing_sequences(5) + tapes = generate_turing_sequences(1, nb_iter=10) for i in range(tapes.size(1)): - print(f"- {i:03d} ------------------------") + # print(f"- {i:03d} ------------------------") # for s, h, r in zip(state, position, tape): # print("".join([f"{x}" for x in r])) # print(" " * h + f"^[{s}]") -- 2.20.1