projects
/
beaver.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[beaver.git]
/
beaver.py
diff --git
a/beaver.py
b/beaver.py
index
920a446
..
a289867
100755
(executable)
--- a/
beaver.py
+++ b/
beaver.py
@@
-75,11
+75,11
@@
parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
##############################
# maze options
##############################
# maze options
-parser.add_argument("--
world
_height", type=int, default=13)
+parser.add_argument("--
maze
_height", type=int, default=13)
-parser.add_argument("--
world
_width", type=int, default=21)
+parser.add_argument("--
maze
_width", type=int, default=21)
-parser.add_argument("--
world
_nb_walls", type=int, default=15)
+parser.add_argument("--
maze
_nb_walls", type=int, default=15)
######################################################################
######################################################################
@@
-262,7
+262,13
@@
class TaskMaze(Task):
mazes, paths = self.seq2map(input)
_, predicted_paths = self.seq2map(result)
mazes, paths = self.seq2map(input)
_, predicted_paths = self.seq2map(result)
- maze.save_image(f"result_{n_epoch:04d}.png", mazes, paths, predicted_paths)
+ maze.save_image(
+ f"result_{n_epoch:04d}.png",
+ mazes,
+ paths,
+ predicted_paths,
+ maze.path_correctness(mazes, predicted_paths),
+ )
model.train(t)
model.train(t)
@@
-276,9
+282,9
@@
task = TaskMaze(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
batch_size=args.batch_size,
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
batch_size=args.batch_size,
- height=args.
world
_height,
- width=args.
world
_width,
- nb_walls=args.
world
_nb_walls,
+ height=args.
maze
_height,
+ width=args.
maze
_width,
+ nb_walls=args.
maze
_nb_walls,
device=device,
)
device=device,
)