projects
/
picoclvr.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
d8ec2eb
)
Update.
author
François Fleuret
<francois@fleuret.org>
Fri, 23 Jun 2023 05:48:36 +0000
(07:48 +0200)
committer
François Fleuret
<francois@fleuret.org>
Fri, 23 Jun 2023 05:48:36 +0000
(07:48 +0200)
main.py
patch
|
blob
|
history
diff --git
a/main.py
b/main.py
index
0144817
..
784474f
100755
(executable)
--- a/
main.py
+++ b/
main.py
@@
-1,4
+1,4
@@
-
#
!/usr/bin/env python
+!/usr/bin/env python
# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/
# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/
@@
-622,7
+622,7
@@
class TaskMaze(Task):
def compute_error(self, model, split="train", nb_to_use=-1):
nb_total, nb_correct = 0, 0
count = torch.zeros(
def compute_error(self, model, split="train", nb_to_use=-1):
nb_total, nb_correct = 0, 0
count = torch.zeros(
- self.width * self.height, self.width * self.height, device=self.device
+ self.width * self.height, self.width * self.height, device=self.device
, dtype=torch.int64
)
for input in task.batches(split, nb_to_use):
result = input.clone()
)
for input in task.batches(split, nb_to_use):
result = input.clone()
@@
-676,6
+676,8
@@
class TaskMaze(Task):
)
if count is not None:
)
if count is not None:
+ proportion_optimal = count.diagonal().sum().float() / count.sum()
+ log_string(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
with open(
os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
) as f:
with open(
os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
) as f: