X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=stack.py;h=543f04e5628b7276769241922c3770afd5d416f9;hb=08b58304225e044a21419dd30302d985acc1824c;hp=d3be4f87e1984094bbb9524a4c4fb0b3e9cefb6c;hpb=5dad808c1e8e72c40711b0350b1c1bebee16a446;p=picoclvr.git diff --git a/stack.py b/stack.py index d3be4f8..543f04e 100755 --- a/stack.py +++ b/stack.py @@ -13,53 +13,95 @@ import torch, torchvision # CODE_VAL=val + 2 * nb_stacks -def generate(nb, seq_len, nb_stacks, nb_values): - stack = torch.empty(nb, nb_stacks, seq_len, dtype=torch.int64) - stack_pointers = torch.zeros(nb, nb_stacks, dtype=torch.int64) +def generate_sequences( + nb, nb_steps, nb_stacks, nb_digits, values=None, device=torch.device("cpu") +): + stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64) + stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64) k = torch.arange(nb) - result = torch.empty(nb, 2 * seq_len, dtype=torch.int64) + result = torch.empty(nb, (1 + nb_digits) * nb_steps, dtype=torch.int64) + recorded_stack_counts = torch.zeros( + nb, (1 + nb_digits) * nb_steps, dtype=torch.int64 + ) - for t in range(seq_len): + for t in range(nb_steps): op = torch.randint(2, (nb,)) st = torch.randint(nb_stacks, (nb,)) - op = op * (stack_pointers[k, st] > 0) - val_push = torch.randint(nb_values, (nb,)) - # top_val[n,s]=stack[n,stack_pointers[n,s]] - top_values = stack[ + op = op * (stack_counts[k, st] > 0) + if values is None: + val_push = torch.randint(10**nb_digits, (nb,)) + else: + val_push = values[torch.randint(values.size(0), (nb,))] + val_pop = stack[ k, st, - (stack_pointers[k, st] - 1).clamp(min=0), + (stack_counts[k, st] - 1).clamp(min=0), ] - stack[ - k[:, None].expand_as(stack_pointers), - st[:, None].expand_as(stack_pointers), - stack_pointers, - ] = val_push[:, None].expand_as(stack_pointers) - stack_pointers[k[op == 0], st[op == 0]] += 1 - stack_pointers[k[op == 1], st[op == 1]] -= 1 - result[:, 2 * t] = st * 2 + op - result[:, 2 * t + 1] = (op * top_values + (1 - op) * val_push) + 2 * nb_stacks + stack[k, st, stack_counts[k, st]] = val_push + recorded_stack_counts[:, (1 + nb_digits) * t] = stack_counts[k, st] + stack_counts[k[op == 0], st[op == 0]] += 1 + stack_counts[k[op == 1], st[op == 1]] -= 1 + result[:, (1 + nb_digits) * t] = st * 2 + op + for d in range(nb_digits): + result[:, (1 + nb_digits) * t + 1 + d] = ( + (op * val_pop + (1 - op) * val_push) // (10**d) + ) % 10 + 2 * nb_stacks - return result + return result.to(device), recorded_stack_counts.to(device) -def seq_to_str(seq): - assert seq.size(0) % 2 == 0 +def remove_popped_values(seq, nb_stacks, nb_digits): + m = torch.logical_and(seq % 2 == 1, seq < 2 * nb_stacks).long() + for d in range(nb_digits): + k = d + 1 + seq[:, k:] = -m[:, :-k] + (1 - m[:, :-k]) * seq[:, k:] + + +def seq_to_str(seq, nb_stacks, nb_digits, recorded_stack_counts=None): + assert seq.size(0) % (1 + nb_digits) == 0 s = "" - for t in range(0, seq.size(0), 2): - op = seq[t] - op = f"POP_{op//2}" if op % 2 == 1 else f"PUSH_{op//2}" - val = seq[t + 1] - 2 * nb_stacks + for t in range(seq.size(0) // (1 + nb_digits)): + n_op = seq[(1 + nb_digits) * t] if t > 0: s += " " - s += f"{op} {val}" + if recorded_stack_counts is not None: + s += f"[{recorded_stack_counts[(1 + nb_digits)*t]}] " + s += f"POP" if n_op % 2 == 1 else f"PSH" + if nb_stacks > 1: + s += f"_{n_op//2}" + for d in range(nb_digits): + if seq[(1 + nb_digits) * t + 1 + d] == -1: + s += " ?" + else: + s += f" {seq[(1 + nb_digits) * t + 1 + d] - 2 * nb_stacks:1d}" return s ###################################################################### if __name__ == "__main__": - nb, seq_len, nb_stacks, nb_values = 3, 10, 1, 5 - result = generate(nb=nb, seq_len=seq_len, nb_stacks=nb_stacks, nb_values=nb_values) - for n in range(result.size(0)): - print(seq_to_str(result[n])) + nb, nb_steps, nb_stacks, nb_digits = 150000, 20, 2, 1 + seq, recorded_stack_counts = generate_sequences( + nb=nb, + nb_steps=nb_steps, + nb_stacks=nb_stacks, + nb_digits=nb_digits, + ) + + for n in range(min(10, seq.size(0))): + print( + seq_to_str( + seq[n], + nb_stacks=nb_stacks, + nb_digits=nb_digits, + recorded_stack_counts=recorded_stack_counts[n], + ) + ) + # print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits)) + + print("-- PREPARED FOR TEST -----------------") + + remove_popped_values(seq, nb_stacks, nb_digits) + + for n in range(min(10, seq.size(0))): + print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits))