nb_train_samples,
nb_test_samples,
batch_size,
- height,
- width,
+ size,
logger=None,
device=torch.device("cpu"),
):
self.device = device
self.batch_size = batch_size
- self.grid_factory = grid.GridFactory(height=height, width=width)
+ self.grid_factory = grid.GridFactory(size=size)
if logger is not None:
logger(
nb_total = ar_mask.sum().item()
nb_correct = ((correct == result).long() * ar_mask).sum().item()
- logger(f"test_performance {nb_total=} {nb_correct=}")
- logger(f"main_test_accuracy {nb_correct / nb_total}")
+ logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
+ logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
######################################################################