parser.add_argument("--batch_size", type=int, default=None)
-parser.add_argument("--nb_train_samples", type=int, default=250000)
+parser.add_argument("--nb_train_samples", type=int, default=None)
-parser.add_argument("--nb_test_samples", type=int, default=10000)
+parser.add_argument("--nb_test_samples", type=int, default=None)
parser.add_argument("--optim", type=str, default="adam")
parser.add_argument("--stack_nb_stacks", type=int, default=1)
-parser.add_argument("--stack_nb_values", type=int, default=10)
+parser.add_argument("--stack_nb_digits", type=int, default=1)
######################################################################
# entropy[:,s]= p.xlogy(p).sum(1) / math.log(2)
batches = zip(input.split(batch_size), ar_mask.split(batch_size))
if progress_bar_desc is not None:
- tqdm.tqdm(
+ batches = tqdm.tqdm(
batches,
dynamic_ncols=True,
desc=progress_bar_desc,
batch_size,
nb_steps,
nb_stacks,
- nb_values,
+ nb_digits,
device=torch.device("cpu"),
):
self.batch_size = batch_size
self.nb_steps = nb_steps
self.nb_stacks = nb_stacks
- self.nb_values = nb_values
+ self.nb_digits = nb_digits
self.device = device
self.train_input, self.train_stack_counts = stack.generate_sequences(
- nb_train_samples, nb_steps, nb_stacks, nb_values, self.device
+ nb_train_samples, nb_steps, nb_stacks, nb_digits, self.device
)
self.test_input, self.test_stack_counts = stack.generate_sequences(
- nb_test_samples, nb_steps, nb_stacks, nb_values, self.device
+ nb_test_samples, nb_steps, nb_stacks, nb_digits, self.device
)
mask = self.test_input.clone()
- stack.remove_poped_values(mask,self.nb_stacks)
- mask=(mask!=self.test_input)
+ stack.remove_popped_values(mask, self.nb_stacks, self.nb_digits)
+ mask = mask != self.test_input
counts = self.test_stack_counts.flatten()[mask.flatten()]
- counts=F.one_hot(counts).sum(0)
+ counts = F.one_hot(counts).sum(0)
log_string(f"stack_count {counts}")
self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
def compute_nb_correct(input):
result = input.clone()
- stack.remove_poped_values(result,self.nb_stacks)
+ stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
ar_mask = (result != input).long()
- result *= 1 - ar_mask
-
masked_inplace_autoregression(
model, self.batch_size, result, ar_mask, device=self.device
)
- nb_total = ar_mask.sum()
+ errors = ((result != input).long() * ar_mask).reshape(
+ -1, 1 + self.nb_digits
+ )
+ ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits)
- nb_correct = (
- (result == input).long() * ar_mask
- ).sum()
+ nb_total = ar_mask.max(1).values.sum()
+ nb_correct = nb_total - errors.max(1).values.sum()
return nb_total, nb_correct
f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
)
+ #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ input = self.test_input[:10, :20]
+ result = input.clone()
+ stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
+ ar_mask = (result != input).long()
+ for n in range(result.size(0)):
+ log_string(
+ f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
+ )
+ masked_inplace_autoregression(
+ model, self.batch_size, result, ar_mask, device=self.device
+ )
+ for n in range(result.size(0)):
+ log_string(
+ f"test_after {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
+ )
+ #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
model.train(t)
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
batch_size=args.batch_size,
- nb_steps = args.stack_nb_steps,
- nb_stacks = args.stack_nb_stacks,
- nb_values = args.stack_nb_values,
+ nb_steps=args.stack_nb_steps,
+ nb_stacks=args.stack_nb_stacks,
+ nb_digits=args.stack_nb_digits,
device=device,
)
# CODE_VAL=val + 2 * nb_stacks
-def generate_sequences(nb, nb_steps, nb_stacks, nb_values, device=torch.device("cpu")):
+def generate_sequences(nb, nb_steps, nb_stacks, nb_digits, device=torch.device("cpu")):
stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64)
stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64)
k = torch.arange(nb)
- result = torch.empty(nb, 2 * nb_steps, dtype=torch.int64)
- recorded_stack_counts = torch.zeros(nb, 2 * nb_steps, dtype=torch.int64)
+ result = torch.empty(nb, (1 + nb_digits) * nb_steps, dtype=torch.int64)
+ recorded_stack_counts = torch.zeros(
+ nb, (1 + nb_digits) * nb_steps, dtype=torch.int64
+ )
for t in range(nb_steps):
op = torch.randint(2, (nb,))
st = torch.randint(nb_stacks, (nb,))
op = op * (stack_counts[k, st] > 0)
- val_push = torch.randint(nb_values, (nb,))
+ val_push = torch.randint(10**nb_digits, (nb,))
val_pop = stack[
k,
st,
(stack_counts[k, st] - 1).clamp(min=0),
]
stack[k, st, stack_counts[k, st]] = val_push
- recorded_stack_counts[:, 2 * t + 1] = stack_counts[k, st]
+ recorded_stack_counts[:, (1 + nb_digits) * t + 1] = stack_counts[k, st]
stack_counts[k[op == 0], st[op == 0]] += 1
stack_counts[k[op == 1], st[op == 1]] -= 1
- result[:, 2 * t] = st * 2 + op
- result[:, 2 * t + 1] = (op * val_pop + (1 - op) * val_push) + 2 * nb_stacks
+ result[:, (1 + nb_digits) * t] = st * 2 + op
+ for d in range(nb_digits):
+ result[:, (1 + nb_digits) * t + 1 + d] = (
+ (op * val_pop + (1 - op) * val_push) // (10**d)
+ ) % 10 + 2 * nb_stacks
return result.to(device), recorded_stack_counts.to(device)
-def remove_poped_values(seq, nb_stacks):
+def remove_popped_values(seq, nb_stacks, nb_digits):
m = torch.logical_and(seq % 2 == 1, seq < 2 * nb_stacks).long()
- seq[:, 1:] = -m[:, :-1] + (1 - m[:, :-1]) * seq[:, 1:]
+ for d in range(nb_digits):
+ k = d + 1
+ seq[:, k:] = -m[:, :-k] + (1 - m[:, :-k]) * seq[:, k:]
-def seq_to_str(seq, show_stack_nb=True,recorded_stack_counts=None):
- assert seq.size(0) % 2 == 0
+def seq_to_str(seq, nb_stacks, nb_digits, recorded_stack_counts=None):
+ assert seq.size(0) % (1 + nb_digits) == 0
s = ""
- for t in range(seq.size(0) // 2):
- n_op = seq[2 * t]
- op = f"POP" if n_op % 2 == 1 else f"PSH"
- if show_stack_nb: op+=f"_{n_op//2}"
- if seq[2 * t + 1] == -1:
- val = "?"
- else:
- val = seq[2 * t + 1] - 2 * nb_stacks
+ for t in range(seq.size(0) // (1 + nb_digits)):
+ n_op = seq[(1 + nb_digits) * t]
if t > 0:
s += " "
+ s += f"POP" if n_op % 2 == 1 else f"PSH"
+ if nb_stacks > 1:
+ s += f"_{n_op//2}"
+ for d in range(nb_digits):
+ if seq[(1 + nb_digits) * t + 1 + d] == -1:
+ s += " ?"
+ else:
+ s += f" {seq[(1 + nb_digits) * t + 1 + d] - 2 * nb_stacks:1d}"
if recorded_stack_counts is not None:
- s += f"[{recorded_stack_counts[2*t+1]}] "
- s += f"{op} {val}"
+ s += f"[{recorded_stack_counts[(1 + nb_digits)*t+1]}] "
return s
######################################################################
if __name__ == "__main__":
- nb, nb_steps, nb_stacks, nb_values = 150000, 10, 1, 5
+ nb, nb_steps, nb_stacks, nb_digits = 150000, 10, 1, 1
seq, recorded_stack_counts = generate_sequences(
- nb=nb, nb_steps=nb_steps, nb_stacks=nb_stacks, nb_values=nb_values
+ nb=nb,
+ nb_steps=nb_steps,
+ nb_stacks=nb_stacks,
+ nb_digits=nb_digits,
)
print("-- TRAIN -----------------------------")
for n in range(min(10, seq.size(0))):
# print(seq_to_str(seq[n], recorded_stack_counts[n]))
- print(seq_to_str(seq[n],show_stack_nb=nb_stacks>1))
+ print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits))
print("-- TEST ------------------------------")
- remove_poped_values(seq, nb_stacks)
+ remove_popped_values(seq, nb_stacks, nb_digits)
for n in range(min(10, seq.size(0))):
- print(seq_to_str(seq[n],show_stack_nb=nb_stacks>1))
+ print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits))