From 9c4098a744698138e68cf379d2869b17d407c085 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 23 Jun 2023 07:48:36 +0200 Subject: [PATCH] Update. --- main.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index 0144817..784474f 100755 --- a/main.py +++ b/main.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +!/usr/bin/env python # Any copyright is dedicated to the Public Domain. # https://creativecommons.org/publicdomain/zero/1.0/ @@ -622,7 +622,7 @@ class TaskMaze(Task): def compute_error(self, model, split="train", nb_to_use=-1): nb_total, nb_correct = 0, 0 count = torch.zeros( - self.width * self.height, self.width * self.height, device=self.device + self.width * self.height, self.width * self.height, device=self.device, dtype=torch.int64 ) for input in task.batches(split, nb_to_use): result = input.clone() @@ -676,6 +676,8 @@ class TaskMaze(Task): ) if count is not None: + proportion_optimal = count.diagonal().sum().float() / count.sum() + log_string(f"proportion_optimal_test {proportion_optimal*100:.02f}%") with open( os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.txt"), "w" ) as f: -- 2.39.5