# Make a list of strings from a tensor.
def detensorize(self, x):
    """Convert a 2D tensor of token ids into a list of strings.

    Each row r of x becomes one string: the tokens of r looked up in
    self.id2token and joined with single spaces.

    Args:
        x: iterable of rows of 0-dim tensors (e.g. a 2D torch tensor);
           each element must support .item().

    Returns:
        list[str], one entry per row of x.
    """

    def id2token(t):
        # EAFP: an id missing from the vocabulary is rendered as "?"
        # instead of raising KeyError.
        try:
            return self.id2token[t.item()]
        except KeyError:
            return "?"

    return [" ".join([id2token(t) for t in r]) for r in x]
# Trim all the tensors in the tuple z to remove as many tokens as possible
# from the left and right of the first tensor. If z is a tuple, all its
def compute_nb_correct(input):
result = input.clone()
stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
+
ar_mask = (result != input).long()
+ result *= 1 - ar_mask
+
masked_inplace_autoregression(
model,
self.batch_size,
stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
ar_mask = (result != input).long()
- # for n in range(result.size(0)):
- # logger(
- # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
- # )
+ for n in range(result.size(0)):
+ logger(
+ f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
+ )
+
+ result *= 1 - ar_mask
masked_inplace_autoregression(
model,
# Make a list of strings from a tensor.
def tensor2str(self, x):
    """Convert a 2D tensor of token ids into a list of strings.

    Each row r of x becomes one string: the tokens of r looked up in
    self.id2token and joined with single spaces.

    Args:
        x: iterable of rows of 0-dim tensors (e.g. a 2D torch tensor);
           each element must support .item().

    Returns:
        list[str], one entry per row of x.
    """

    def id2token(t):
        # EAFP: an id missing from the vocabulary is rendered as "?"
        # instead of raising KeyError.
        try:
            return self.id2token[t.item()]
        except KeyError:
            return "?"

    return [" ".join([id2token(t) for t in r]) for r in x]
# Trim all the tensors in the tuple z to remove as many tokens as possible
# from the left and right of the first tensor. If z is a tuple, all its