projects
/
picoclvr.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[picoclvr.git]
/
main.py
diff --git
a/main.py
b/main.py
index
8c4b7a1
..
2ed6b6b
100755
(executable)
--- a/
main.py
+++ b/
main.py
@@
-37,7
+37,7
@@
parser.add_argument(
parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
-parser.add_argument("--result_dir", type=str, default=
"results_default"
)
+parser.add_argument("--result_dir", type=str, default=
None
)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--seed", type=int, default=0)
@@
-144,30
+144,35
@@
if args.seed >= 0:
default_args = {
"picoclvr": {
default_args = {
"picoclvr": {
+ "result_dir": "results_picoclvr",
"nb_epochs": 25,
"batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"mnist": {
"nb_epochs": 25,
"batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"mnist": {
+ "result_dir": "results_mnist",
"nb_epochs": 25,
"batch_size": 10,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"maze": {
"nb_epochs": 25,
"batch_size": 10,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"maze": {
+ "result_dir": "results_maze",
"nb_epochs": 25,
"batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"snake": {
"nb_epochs": 25,
"batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"snake": {
+ "result_dir": "results_snake",
"nb_epochs": 5,
"batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"stack": {
"nb_epochs": 5,
"batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"stack": {
+ "result_dir": "results_stack",
"nb_epochs": 5,
"batch_size": 25,
"nb_train_samples": 100000,
"nb_epochs": 5,
"batch_size": 25,
"nb_train_samples": 100000,
@@
-968,8
+973,8
@@
class TaskStack(Task):
)
#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
)
#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
- l
=
50
- l
=l-l%(1+
self.nb_digits)
+ l
=
50
+ l
= l - l % (1 +
self.nb_digits)
input = self.test_input[:10, :l]
result = input.clone()
stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
input = self.test_input[:10, :l]
result = input.clone()
stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)