projects
/
picoclvr.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[picoclvr.git]
/
main.py
diff --git
a/main.py
b/main.py
index
0323d02
..
14b1bc3
100755
(executable)
--- a/
main.py
+++ b/
main.py
@@
-109,7
+109,7
@@
parser.add_argument("--snake_length", type=int, default=200)
##############################
# Snake options
##############################
# Snake options
-parser.add_argument("--stack_nb_steps", type=int, default=
25
)
+parser.add_argument("--stack_nb_steps", type=int, default=
100
)
parser.add_argument("--stack_nb_stacks", type=int, default=1)
parser.add_argument("--stack_nb_stacks", type=int, default=1)
@@
-166,9
+166,9
@@
default_args = {
"nb_test_samples": 10000,
},
"stack": {
"nb_test_samples": 10000,
},
"stack": {
- "nb_epochs":
2
5,
+ "nb_epochs": 5,
"batch_size": 25,
"batch_size": 25,
- "nb_train_samples": 10000,
+ "nb_train_samples": 10000
0
,
"nb_test_samples": 1000,
},
}
"nb_test_samples": 1000,
},
}
@@
-892,6
+892,13
@@
class TaskStack(Task):
nb_test_samples, nb_steps, nb_stacks, nb_values, self.device
)
nb_test_samples, nb_steps, nb_stacks, nb_values, self.device
)
+ mask = self.test_input.clone()
+ stack.remove_poped_values(mask,self.nb_stacks)
+ mask=(mask!=self.test_input)
+ counts = self.test_stack_counts.flatten()[mask.flatten()]
+ counts=F.one_hot(counts).sum(0)
+ log_string(f"stack_count {counts}")
+
self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
def batches(self, split="train", nb_to_use=-1, desc=None):
self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
def batches(self, split="train", nb_to_use=-1, desc=None):