projects
/
beaver.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
40f25d4
)
Update
master
author
François Fleuret
<francois@fleuret.org>
Fri, 7 Apr 2023 21:35:12 +0000
(23:35 +0200)
committer
François Fleuret
<francois@fleuret.org>
Fri, 7 Apr 2023 21:35:12 +0000
(23:35 +0200)
beaver.py
patch
|
blob
|
history
diff --git
a/beaver.py
b/beaver.py
index
e69f151
..
5abe39b
100755
(executable)
--- a/
beaver.py
+++ b/
beaver.py
@@
-238,9
+238,9
@@
def oneshot_trace_loss(mazes, output, policies, height, width):
def oneshot(model, learning_rate_scheduler, task):
t = model.training
model.eval()
def oneshot(model, learning_rate_scheduler, task):
t = model.training
model.eval()
- mazes = task.test_input[:
32
].clone()
+ mazes = task.test_input[:
48
].clone()
mazes[:, task.height * task.width :] = 0
mazes[:, task.height * task.width :] = 0
- policies = task.test_policies[:
32
]
+ policies = task.test_policies[:
48
]
targets = maze.stationary_densities(
mazes[:, : task.height * task.width].view(-1, task.height, task.width),
policies.view(-1, 4, task.height, task.width),
targets = maze.stationary_densities(
mazes[:, : task.height * task.width].view(-1, task.height, task.width),
policies.view(-1, 4, task.height, task.width),
@@
-253,7
+253,7
@@
def oneshot(model, learning_rate_scheduler, task):
)
mazes = mazes[:, : task.height * task.width].reshape(-1, task.height, task.width)
targets = targets.reshape(-1, task.height, task.width)
)
mazes = mazes[:, : task.height * task.width].reshape(-1, task.height, task.width)
targets = targets.reshape(-1, task.height, task.width)
- paths = task.test_input[:
32
, task.height * task.width :].reshape(
+ paths = task.test_input[:
48
, task.height * task.width :].reshape(
-1, task.height, task.width
)
filename = f"oneshot.png"
-1, task.height, task.width
)
filename = f"oneshot.png"
@@
-335,8
+335,8
@@
def oneshot_old(gpt, learning_rate_scheduler, task):
)
# -------------------
)
# -------------------
- mazes = task.test_input[:
32
, : task.height * task.width]
- policies = task.test_policies[:
32
]
+ mazes = task.test_input[:
48
, : task.height * task.width]
+ policies = task.test_policies[:
48
]
output_gpt = eval_mygpt(
gpt, mazes, mode=args.oneshot_input, prompt_len=task.height * task.width
)
output_gpt = eval_mygpt(
gpt, mazes, mode=args.oneshot_input, prompt_len=task.height * task.width
)
@@
-579,7
+579,7
@@
class TaskMaze(Task):
f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
)
f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
)
- input = self.test_input[:
32
]
+ input = self.test_input[:
48
]
result = input.clone()
ar_mask = result.new_zeros(result.size())
ar_mask[:, self.height * self.width :] = 1
result = input.clone()
ar_mask = result.new_zeros(result.size())
ar_mask[:, self.height * self.width :] = 1