nb_total = input.size(0)
nb_correct = (input == result).long().min(1).values.sum()
+ #######################################################################
+ # Comput predicted vs. true variable values
+
values_input = expr.extract_results([self.seq2str(s) for s in input])
max_input = max([max(x.values()) for x in values_input])
values_result = expr.extract_results([self.seq2str(s) for s in result])
[-1 if len(x) == 0 else max(x.values()) for x in values_result]
)
- nb_missing, nb_predicted = torch.zeros(max_input + 1), torch.zeros(
- max_input + 1, max_result + 1
- )
+ nb_missing = torch.zeros(max_input + 1)
+ nb_predicted = torch.zeros(max_input + 1, max_result + 1)
+
for i, r in zip(values_input, values_result):
for n, vi in i.items():
vr = r.get(n)
nb_missing[vi] += 1
else:
nb_predicted[vi, vr] += 1
+ ######################################################################
return nb_total, nb_correct