parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
-parser.add_argument("--result_dir", type=str, default="results_default")
+parser.add_argument("--result_dir", type=str, default=None)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--stack_nb_stacks", type=int, default=1)
-parser.add_argument("--stack_nb_digits", type=int, default=1)
+parser.add_argument("--stack_nb_digits", type=int, default=3)
+
+parser.add_argument("--stack_fraction_values_for_train", type=float, default=None)
######################################################################
default_args = {
"picoclvr": {
+ "result_dir": "results_picoclvr",
"nb_epochs": 25,
"batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"mnist": {
+ "result_dir": "results_mnist",
"nb_epochs": 25,
"batch_size": 10,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"maze": {
+ "result_dir": "results_maze",
"nb_epochs": 25,
"batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"snake": {
+ "result_dir": "results_snake",
"nb_epochs": 5,
"batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"stack": {
+ "result_dir": "results_stack",
"nb_epochs": 5,
"batch_size": 25,
"nb_train_samples": 100000,
nb_steps,
nb_stacks,
nb_digits,
+ fraction_values_for_train=None,
device=torch.device("cpu"),
):
self.batch_size = batch_size
self.nb_digits = nb_digits
self.device = device
+ if fraction_values_for_train is None:
+ values_for_train = None
+ values_for_test = None
+ else:
+ all = torch.randperm(10**nb_digits)
+ nb_for_train = int(all.size(0) * fraction_values_for_train)
+ values_for_train = all[:nb_for_train]
+ values_for_test = all[nb_for_train:]
+
self.train_input, self.train_stack_counts = stack.generate_sequences(
- nb_train_samples, nb_steps, nb_stacks, nb_digits, self.device
+ nb_train_samples,
+ nb_steps,
+ nb_stacks,
+ nb_digits,
+ values_for_train,
+ self.device,
)
self.test_input, self.test_stack_counts = stack.generate_sequences(
- nb_test_samples, nb_steps, nb_stacks, nb_digits, self.device
+ nb_test_samples,
+ nb_steps,
+ nb_stacks,
+ nb_digits,
+ values_for_test,
+ self.device,
)
mask = self.test_input.clone()
)
#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
- input = self.test_input[:10, :20]
+ l = 50
+ l = l - l % (1 + self.nb_digits)
+ input = self.test_input[:10, :l]
result = input.clone()
stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
ar_mask = (result != input).long()
nb_steps=args.stack_nb_steps,
nb_stacks=args.stack_nb_stacks,
nb_digits=args.stack_nb_digits,
+ fraction_values_for_train=args.stack_fraction_values_for_train,
device=device,
)