projects
/
picoclvr.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[picoclvr.git]
/
main.py
diff --git
a/main.py
b/main.py
index
9679236
..
7cb8d4f
100755
(executable)
--- a/
main.py
+++ b/
main.py
@@
-102,7
+102,7
@@
parser.add_argument("--snake_width", type=int, default=8)
parser.add_argument("--snake_nb_colors", type=int, default=5)
parser.add_argument("--snake_nb_colors", type=int, default=5)
-parser.add_argument("--snake_length", type=int, default=
4
00)
+parser.add_argument("--snake_length", type=int, default=
2
00)
######################################################################
######################################################################
@@
-143,8
+143,8
@@
default_args = {
"batch_size": 25,
},
"snake": {
"batch_size": 25,
},
"snake": {
- "nb_epochs":
2
5,
- "batch_size": 2
0
,
+ "nb_epochs": 5,
+ "batch_size": 2
5
,
},
}
},
}
@@
-689,7
+689,7
@@
class TaskSnake(Task):
self.device = device
self.prompt_length = prompt_length
self.device = device
self.prompt_length = prompt_length
- self.train_input, self.train_prior_visits = snake.generate_sequences(
+ self.train_input, self.train_prior_visits
, _, _
= snake.generate_sequences(
nb_train_samples,
height,
width,
nb_train_samples,
height,
width,
@@
-698,7
+698,7
@@
class TaskSnake(Task):
prompt_length,
self.device,
)
prompt_length,
self.device,
)
- self.test_input, self.test_prior_visits = snake.generate_sequences(
+ self.test_input, self.test_prior_visits
, _, _
= snake.generate_sequences(
nb_test_samples,
height,
width,
nb_test_samples,
height,
width,