-#!/usr/bin/env python
+!/usr/bin/env python
# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/
def compute_error(self, model, split="train", nb_to_use=-1):
nb_total, nb_correct = 0, 0
count = torch.zeros(
- self.width * self.height, self.width * self.height, device=self.device
+ self.width * self.height, self.width * self.height, device=self.device, dtype=torch.int64
)
for input in task.batches(split, nb_to_use):
result = input.clone()
)
if count is not None:
+ proportion_optimal = count.diagonal().sum().float() / count.sum()
+ log_string(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
with open(
os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
) as f: