From: François Fleuret Date: Fri, 23 Jun 2023 05:48:36 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;ds=inline;h=9c4098a744698138e68cf379d2869b17d407c085;p=picoclvr.git Update. --- diff --git a/main.py b/main.py index 0144817..784474f 100755 --- a/main.py +++ b/main.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +!/usr/bin/env python # Any copyright is dedicated to the Public Domain. # https://creativecommons.org/publicdomain/zero/1.0/ @@ -622,7 +622,7 @@ class TaskMaze(Task): def compute_error(self, model, split="train", nb_to_use=-1): nb_total, nb_correct = 0, 0 count = torch.zeros( - self.width * self.height, self.width * self.height, device=self.device + self.width * self.height, self.width * self.height, device=self.device, dtype=torch.int64 ) for input in task.batches(split, nb_to_use): result = input.clone() @@ -676,6 +676,8 @@ class TaskMaze(Task): ) if count is not None: + proportion_optimal = count.diagonal().sum().float() / count.sum() + log_string(f"proportion_optimal_test {proportion_optimal*100:.02f}%") with open( os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.txt"), "w" ) as f: