parser.add_argument("--batch_size", type=int, default=None)
-parser.add_argument("--nb_train_samples", type=int, default=250000)
+parser.add_argument("--nb_train_samples", type=int, default=None)
-parser.add_argument("--nb_test_samples", type=int, default=10000)
+parser.add_argument("--nb_test_samples", type=int, default=None)
parser.add_argument("--optim", type=str, default="adam")
##############################
# Snake options
-parser.add_argument("--stack_nb_steps", type=int, default=25)
+parser.add_argument("--stack_nb_steps", type=int, default=100)
parser.add_argument("--stack_nb_stacks", type=int, default=1)
-parser.add_argument("--stack_nb_values", type=int, default=10)
+parser.add_argument("--stack_nb_digits", type=int, default=3)
+
+parser.add_argument("--stack_fraction_values_for_train", type=float, default=None)
######################################################################
"nb_test_samples": 10000,
},
"stack": {
- "nb_epochs": 25,
+ "nb_epochs": 5,
"batch_size": 25,
- "nb_train_samples": 10000,
+ "nb_train_samples": 100000,
"nb_test_samples": 1000,
},
}
# entropy[:,s]= p.xlogy(p).sum(1) / math.log(2)
batches = zip(input.split(batch_size), ar_mask.split(batch_size))
if progress_bar_desc is not None:
- tqdm.tqdm(
+ batches = tqdm.tqdm(
batches,
dynamic_ncols=True,
desc=progress_bar_desc,
batch_size,
nb_steps,
nb_stacks,
- nb_values,
+ nb_digits,
+ fraction_values_for_train=None,
device=torch.device("cpu"),
):
self.batch_size = batch_size
self.nb_steps = nb_steps
self.nb_stacks = nb_stacks
- self.nb_values = nb_values
+ self.nb_digits = nb_digits
self.device = device
+ if fraction_values_for_train is None:
+ values_for_train = None
+ values_for_test = None
+ else:
+ all = torch.randperm(10**nb_digits)
+ nb_for_train = int(all.size(0) * fraction_values_for_train)
+ values_for_train = all[:nb_for_train]
+ values_for_test = all[nb_for_train:]
+
self.train_input, self.train_stack_counts = stack.generate_sequences(
- nb_train_samples, nb_steps, nb_stacks, nb_values, self.device
+ nb_train_samples,
+ nb_steps,
+ nb_stacks,
+ nb_digits,
+ values_for_train,
+ self.device,
)
self.test_input, self.test_stack_counts = stack.generate_sequences(
- nb_test_samples, nb_steps, nb_stacks, nb_values, self.device
+ nb_test_samples,
+ nb_steps,
+ nb_stacks,
+ nb_digits,
+ values_for_test,
+ self.device,
)
+ mask = self.test_input.clone()
+ stack.remove_popped_values(mask, self.nb_stacks, self.nb_digits)
+ mask = mask != self.test_input
+ counts = self.test_stack_counts.flatten()[mask.flatten()]
+ counts = F.one_hot(counts).sum(0)
+ log_string(f"stack_count {counts}")
+
self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
def batches(self, split="train", nb_to_use=-1, desc=None):
def compute_nb_correct(input):
result = input.clone()
- stack.remove_poped_values(result,self.nb_stacks)
+ stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
ar_mask = (result != input).long()
- result *= 1 - ar_mask
-
masked_inplace_autoregression(
model, self.batch_size, result, ar_mask, device=self.device
)
- nb_total = ar_mask.sum()
+ errors = ((result != input).long() * ar_mask).reshape(
+ -1, 1 + self.nb_digits
+ )
+ ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits)
- nb_correct = (
- (result == input).long() * ar_mask
- ).sum()
+ nb_total = ar_mask.max(1).values.sum()
+ nb_correct = nb_total - errors.max(1).values.sum()
return nb_total, nb_correct
f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
)
+ #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ input = self.test_input[:10, :50]
+ result = input.clone()
+ stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
+ ar_mask = (result != input).long()
+ for n in range(result.size(0)):
+ log_string(
+ f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
+ )
+ masked_inplace_autoregression(
+ model, self.batch_size, result, ar_mask, device=self.device
+ )
+ for n in range(result.size(0)):
+ log_string(
+ f"test_after {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
+ )
+ #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
model.train(t)
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
batch_size=args.batch_size,
- nb_steps = args.stack_nb_steps,
- nb_stacks = args.stack_nb_stacks,
- nb_values = args.stack_nb_values,
+ nb_steps=args.stack_nb_steps,
+ nb_stacks=args.stack_nb_stacks,
+ nb_digits=args.stack_nb_digits,
+ fraction_values_for_train=args.stack_fraction_values_for_train,
device=device,
)