class ProblemMemory(Problem):
    """Copy-the-bracketed-pattern toy sequence problem.

    Each sequence is random noise tokens containing one delimited
    pattern ``[ ... ]`` and, near the end, a result block ``< ... >``
    holding a copy of that pattern.  The autoregressive mask marks the
    copied tokens, so a model must learn to retrieve and reproduce the
    pattern.

    Token ids: 0 ``[`` start-of-pattern, 1 ``]`` end-of-pattern,
    2 ``<`` start-of-result, 3 ``>`` end-of-result, then
    ``nb_noise_tokens`` noise ids rendered as lowercase letters.
    """

    def __init__(self, len_total=25, max_len_pattern=5, nb_noise_tokens=10):
        """Configure the problem geometry.

        Args:
            len_total: total length of every generated sequence.
            max_len_pattern: upper bound on the nominal pattern length
                (the span actually copied is ``len_pattern + 1`` tokens;
                see :meth:`generate_sequences`).
            nb_noise_tokens: number of distinct noise tokens; at most 26
                so each can be rendered as a letter ``a``..``z``.

        Raises:
            ValueError: if ``nb_noise_tokens`` is outside [1, 26] or
                ``len_total`` is too short to fit a pattern block plus
                the result block.
        """
        if not 1 <= nb_noise_tokens <= 26:
            raise ValueError("nb_noise_tokens must be in [1, 26]")
        # Need at least one valid start position for the pattern block
        # (len_total - (5 + 2*max_len_pattern) must be >= 1, see
        # generate_sequences), so the pattern never overlaps the result.
        if len_total < 2 * max_len_pattern + 6:
            raise ValueError(
                f"len_total must be >= {2 * max_len_pattern + 6} "
                f"for max_len_pattern={max_len_pattern}"
            )
        self.len_total = len_total
        self.max_len_pattern = max_len_pattern
        self.nb_noise_tokens = nb_noise_tokens
        # Delimiter tokens occupy ids 0..3; noise tokens are 4 and up.
        self.start_pattern_token = 0
        self.end_pattern_token = 1
        self.start_result_token = 2
        self.end_result_token = 3
        self.token_string = "[]<>" + "".join(
            chr(ord("a") + k) for k in range(self.nb_noise_tokens)
        )

    def generate_sequences(self, nb):
        """Generate ``nb`` random instances.

        Returns:
            (sequences, ar_mask): two LongTensors of shape
            (nb, len_total).  ``sequences`` holds token ids; ``ar_mask``
            is 1 exactly on the copied tokens inside the result block
            (positions k+1 .. k+1+len_pattern) and 0 elsewhere.
        """
        # Fill everything with random noise tokens (ids >= 4).
        sequences = (
            torch.randint(self.nb_noise_tokens, (nb, self.len_total))
            + self.end_result_token
            + 1
        )
        # Nominal pattern lengths in [1, max_len_pattern].  NOTE: the
        # span copied below is len_pattern + 1 tokens — this off-by-one
        # is preserved from the original data format.
        len_patterns = torch.randint(self.max_len_pattern, (nb,)) + 1
        # Start positions chosen so the pattern block (which can extend
        # to j + max_len_pattern + 2) always ends strictly before the
        # result block starting at k below.
        pattern_positions = torch.randint(
            self.len_total - (5 + 2 * self.max_len_pattern), (nb,)
        )
        # Fixed start of the result block, sized for the longest pattern.
        k = self.len_total - (3 + self.max_len_pattern)
        for i in range(nb):
            length = len_patterns[i]
            j = pattern_positions[i]
            sequences[i, j] = self.start_pattern_token
            sequences[i, j + length + 2] = self.end_pattern_token
            sequences[i, k] = self.start_result_token
            sequences[i, k + length + 2] = self.end_result_token
            # Copy the (length + 1) pattern tokens into the result slot.
            sequences[i, k + 1 : k + 2 + length] = sequences[i, j + 1 : j + 2 + length]

        # Mask exactly the copied tokens, not the "<" / ">" delimiters.
        positions = torch.arange(self.len_total)[None, :]
        ar_mask = (positions > k).long() * (
            positions <= k + 1 + len_patterns[:, None]
        ).long()

        return sequences, ar_mask

    def seq2str(self, seq):
        """Render a 1-D tensor of token ids as a readable string."""
        return "".join(self.token_string[x.item()] for x in seq)
+