f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
)
- if save_attention_image is not None:
+ if save_attention_image is None:
+ logger("no save_attention_image (is pycairo installed?)")
+ else:
for k in range(10):
ns = torch.randint(self.test_input.size(0), (1,)).item()
input = self.test_input[ns : ns + 1].clone()
f"accuracy_output_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
)
- if save_attention_image is not None:
+ if save_attention_image is None:
+ logger("no save_attention_image (is pycairo installed?)")
+ else:
ns = torch.randint(self.test_input.size(0), (1,)).item()
input = self.test_input[ns : ns + 1].clone()
last = (input != self.t_nul).max(0).values.nonzero().max() + 3