projects
/
beaver.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
777e637
)
Update
author
François Fleuret
<francois@fleuret.org>
Thu, 23 Mar 2023 08:48:55 +0000
(09:48 +0100)
committer
François Fleuret
<francois@fleuret.org>
Thu, 23 Mar 2023 08:48:55 +0000
(09:48 +0100)
beaver.py
patch
|
blob
|
history
diff --git
a/beaver.py
b/beaver.py
index
dca97cc
..
bd17365
100755
(executable)
--- a/
beaver.py
+++ b/
beaver.py
@@
-312,15
+312,17
@@
def oneshot(gpt, task):
scores = scores.reshape(-1, task.height, task.width)
mazes = mazes.reshape(-1, task.height, task.width)
targets = targets.reshape(-1, task.height, task.width)
scores = scores.reshape(-1, task.height, task.width)
mazes = mazes.reshape(-1, task.height, task.width)
targets = targets.reshape(-1, task.height, task.width)
+ filename = (
+ f"oneshot_{args.oneshot_input}_{args.oneshot_output}_{n_epoch:04d}.png"
+ )
maze.save_image(
maze.save_image(
- os.path.join(
- args.result_dir,
- f"oneshot_{args.oneshot_input}_{args.oneshot_output}_{n_epoch:04d}.png",
- ),
+ os.path.join(args.result_dir, filename),
mazes=mazes,
score_paths=scores,
score_truth=targets,
)
mazes=mazes,
score_paths=scores,
score_truth=targets,
)
+ log_string(f"wrote {filename}")
+
# -------------------
gpt.train(t)
# -------------------
gpt.train(t)
@@
-471,13
+473,15
@@
class TaskMaze(Task):
mazes, paths = self.seq2map(input)
_, predicted_paths = self.seq2map(result)
mazes, paths = self.seq2map(input)
_, predicted_paths = self.seq2map(result)
+ filename = f"result_{n_epoch:04d}.png"
maze.save_image(
maze.save_image(
- os.path.join(args.result_dir, f
"result_{n_epoch:04d}.png"
),
+ os.path.join(args.result_dir, f
ilename
),
mazes=mazes,
target_paths=paths,
predicted_paths=predicted_paths,
path_correct=maze.path_correctness(mazes, predicted_paths),
)
mazes=mazes,
target_paths=paths,
predicted_paths=predicted_paths,
path_correct=maze.path_correctness(mazes, predicted_paths),
)
+ log_string(f"wrote {filename}")
model.train(t)
model.train(t)