parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
-parser.add_argument("--result_dir", type=str, default="results_default")
+parser.add_argument("--result_dir", type=str, default=None)
parser.add_argument("--seed", type=int, default=0)
default_args = {
"picoclvr": {
+ "result_dir": "results_picoclvr",
"nb_epochs": 25,
"batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"mnist": {
+ "result_dir": "results_mnist",
"nb_epochs": 25,
"batch_size": 10,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"maze": {
+ "result_dir": "results_maze",
"nb_epochs": 25,
"batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"snake": {
+ "result_dir": "results_snake",
"nb_epochs": 5,
"batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"stack": {
+ "result_dir": "results_stack",
"nb_epochs": 5,
"batch_size": 25,
"nb_train_samples": 100000,
)
#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
- input = self.test_input[:10, :50]
+ l = 50
+ l = l - l % (1 + self.nb_digits)
+ input = self.test_input[:10, :l]
result = input.clone()
stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
ar_mask = (result != input).long()