next_state = torch.randint(nb_states, (N, nb_states, nb_symbols))
next_symbol = torch.randint(nb_symbols, (N, nb_states, nb_symbols))
next_move = torch.randint(3, (N, nb_states, nb_symbols))
next_state = torch.randint(nb_states, (N, nb_states, nb_symbols))
next_symbol = torch.randint(nb_symbols, (N, nb_states, nb_symbols))
next_move = torch.randint(3, (N, nb_states, nb_symbols))