projects
/
beaver.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (from parent 1:
519b541
)
Update
author
François Fleuret
<francois@fleuret.org>
Fri, 7 Apr 2023 12:26:46 +0000
(14:26 +0200)
committer
François Fleuret
<francois@fleuret.org>
Fri, 7 Apr 2023 12:26:46 +0000
(14:26 +0200)
beaver.py
patch
|
blob
|
history
diff --git
a/beaver.py
b/beaver.py
index
6e1eaf4
..
9f8bc41
100755
(executable)
--- a/
beaver.py
+++ b/
beaver.py
@@
-238,6
+238,11
@@
def oneshot(model, learning_rate_scheduler, task):
model.eval()
mazes = task.test_input[:32].clone()
mazes[:, task.height * task.width :] = 0
model.eval()
mazes = task.test_input[:32].clone()
mazes[:, task.height * task.width :] = 0
+ policies = task.test_policies[:32]
+ targets = maze.stationary_densities(
+ mazes[:, : task.height * task.width].view(-1, task.height, task.width),
+ policies.view(-1, 4, task.height, task.width),
+ ).flatten(-2)
output = eval_mygpt(model, mazes, prompt_len=task.height * task.width)
output = F.softmax(output, dim=2)
print(f"{output.size()=}")
output = eval_mygpt(model, mazes, prompt_len=task.height * task.width)
output = F.softmax(output, dim=2)
print(f"{output.size()=}")
@@
-245,6
+250,7
@@
def oneshot(model, learning_rate_scheduler, task):
-1, task.height, task.width
)
mazes = mazes[:, : task.height * task.width].reshape(-1, task.height, task.width)
-1, task.height, task.width
)
mazes = mazes[:, : task.height * task.width].reshape(-1, task.height, task.width)
+ targets = targets.reshape(-1, task.height, task.width)
paths = task.test_input[:32, task.height * task.width :].reshape(
-1, task.height, task.width
)
paths = task.test_input[:32, task.height * task.width :].reshape(
-1, task.height, task.width
)
@@
-252,9
+258,9
@@
def oneshot(model, learning_rate_scheduler, task):
maze.save_image(
os.path.join(args.result_dir, filename),
mazes=mazes,
maze.save_image(
os.path.join(args.result_dir, filename),
mazes=mazes,
- target_paths=paths,
+
#
target_paths=paths,
score_paths=proba_path,
score_paths=proba_path,
-
#
score_truth=targets,
+ score_truth=targets,
)
log_string(f"wrote {filename}")
)
log_string(f"wrote {filename}")