)
parser.add_argument(
- "--task", type=str, default="picoclvr", help="picoclvr, mnist, maze, snake"
+ "--task", type=str, default="picoclvr", help="picoclvr, mnist, maze, snake, stack"
)
parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
parser.add_argument("--snake_length", type=int, default=200)
+##############################
+# Snake options
+
+parser.add_argument("--stack_nb_steps", type=int, default=25)
+
+parser.add_argument("--stack_nb_stacks", type=int, default=1)
+
+parser.add_argument("--stack_nb_values", type=int, default=10)
+
######################################################################
args = parser.parse_args()
"picoclvr": {
"nb_epochs": 25,
"batch_size": 25,
+ "nb_train_samples": 250000,
+ "nb_test_samples": 10000,
},
"mnist": {
"nb_epochs": 25,
"batch_size": 10,
+ "nb_train_samples": 250000,
+ "nb_test_samples": 10000,
},
"maze": {
"nb_epochs": 25,
"batch_size": 25,
+ "nb_train_samples": 250000,
+ "nb_test_samples": 10000,
},
"snake": {
"nb_epochs": 5,
"batch_size": 25,
+ "nb_train_samples": 250000,
+ "nb_test_samples": 10000,
+ },
+ "stack": {
+ "nb_epochs": 25,
+ "batch_size": 25,
+ "nb_train_samples": 10000,
+ "nb_test_samples": 1000,
},
}
######################################################################
+import stack
+
+
+class TaskStack(Task):
+ def __init__(
+ self,
+ nb_train_samples,
+ nb_test_samples,
+ batch_size,
+ nb_steps,
+ nb_stacks,
+ nb_values,
+ device=torch.device("cpu"),
+ ):
+ self.batch_size = batch_size
+ self.nb_steps = nb_steps
+ self.nb_stacks = nb_stacks
+ self.nb_values = nb_values
+ self.device = device
+
+ self.train_input, self.train_stack_counts = stack.generate_sequences(
+ nb_train_samples, nb_steps, nb_stacks, nb_values, self.device
+ )
+
+ self.test_input, self.test_stack_counts = stack.generate_sequences(
+ nb_test_samples, nb_steps, nb_stacks, nb_values, self.device
+ )
+
+ self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+
+ def batches(self, split="train", nb_to_use=-1, desc=None):
+ assert split in {"train", "test"}
+ input = self.train_input if split == "train" else self.test_input
+ if nb_to_use > 0:
+ input = input[:nb_to_use]
+ if desc is None:
+ desc = f"epoch-{split}"
+ for batch in tqdm.tqdm(
+ input.split(self.batch_size), dynamic_ncols=True, desc=desc
+ ):
+ yield batch
+
+ def vocabulary_size(self):
+ return self.nb_codes
+
+ def produce_results(self, n_epoch, model):
+ with torch.autograd.no_grad():
+ t = model.training
+ model.eval()
+
+ def compute_nb_correct(input):
+ result = input.clone()
+ stack.remove_poped_values(result,self.nb_stacks)
+ ar_mask = (result != input).long()
+ result *= 1 - ar_mask
+
+ masked_inplace_autoregression(
+ model, self.batch_size, result, ar_mask, device=self.device
+ )
+
+ nb_total = ar_mask.sum()
+
+ nb_correct = (
+ (result == input).long() * ar_mask
+ ).sum()
+
+ return nb_total, nb_correct
+
+ test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])
+
+ log_string(
+ f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+ )
+
+ model.train(t)
+
+
+######################################################################
+
+
def picoclvr_pruner_horizontal_green(p):
return not ("green" in p and ("left" in p or "right" in p))
device=device,
)
+elif args.task == "stack":
+ task = TaskStack(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ nb_steps = args.stack_nb_steps,
+ nb_stacks = args.stack_nb_stacks,
+ nb_values = args.stack_nb_values,
+ device=device,
+ )
+
else:
raise ValueError(f"Unknown task {args.task}")
# CODE_VAL=val + 2 * nb_stacks
-def generate(nb, nb_steps, nb_stacks, nb_values):
+def generate_sequences(nb, nb_steps, nb_stacks, nb_values, device=torch.device("cpu")):
stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64)
- stack_pointers = torch.zeros(nb, nb_stacks, dtype=torch.int64)
+ stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64)
k = torch.arange(nb)
result = torch.empty(nb, 2 * nb_steps, dtype=torch.int64)
- depth_counts = torch.zeros(nb, 2 * nb_steps, dtype=torch.int64)
+ recorded_stack_counts = torch.zeros(nb, 2 * nb_steps, dtype=torch.int64)
for t in range(nb_steps):
op = torch.randint(2, (nb,))
st = torch.randint(nb_stacks, (nb,))
- op = op * (stack_pointers[k, st] > 0)
+ op = op * (stack_counts[k, st] > 0)
val_push = torch.randint(nb_values, (nb,))
val_pop = stack[
k,
st,
- (stack_pointers[k, st] - 1).clamp(min=0),
+ (stack_counts[k, st] - 1).clamp(min=0),
]
- stack[k, st, stack_pointers[k, st]] = val_push
- depth_counts[:, 2 * t + 1] = stack_pointers[k, st]
- stack_pointers[k[op == 0], st[op == 0]] += 1
- stack_pointers[k[op == 1], st[op == 1]] -= 1
+ stack[k, st, stack_counts[k, st]] = val_push
+ recorded_stack_counts[:, 2 * t + 1] = stack_counts[k, st]
+ stack_counts[k[op == 0], st[op == 0]] += 1
+ stack_counts[k[op == 1], st[op == 1]] -= 1
result[:, 2 * t] = st * 2 + op
result[:, 2 * t + 1] = (op * val_pop + (1 - op) * val_push) + 2 * nb_stacks
- return result, depth_counts
+ return result.to(device), recorded_stack_counts.to(device)
-def seq_to_str(seq, depth_counts=None):
+def remove_poped_values(seq, nb_stacks):
+ m = torch.logical_and(seq % 2 == 1, seq < 2 * nb_stacks).long()
+ seq[:, 1:] = -m[:, :-1] + (1 - m[:, :-1]) * seq[:, 1:]
+
+
+def seq_to_str(seq, recorded_stack_counts=None):
assert seq.size(0) % 2 == 0
s = ""
for t in range(seq.size(0) // 2):
op = seq[2 * t]
- op = f"POP_{op//2}" if op % 2 == 1 else f"PUSH_{op//2}"
- val = seq[2 * t + 1] - 2 * nb_stacks
+ op = f"POP_{op//2}" if op % 2 == 1 else f"PSH_{op//2}"
+ if seq[2 * t + 1] == -1:
+ val = "?"
+ else:
+ val = seq[2 * t + 1] - 2 * nb_stacks
if t > 0:
s += " "
- if depth_counts is not None:
- s += f"[{depth_counts[2*t+1]}] "
+ if recorded_stack_counts is not None:
+ s += f"[{recorded_stack_counts[2*t+1]}] "
s += f"{op} {val}"
return s
if __name__ == "__main__":
nb, nb_steps, nb_stacks, nb_values = 150000, 10, 1, 5
- seq, depth_counts = generate(
+ seq, recorded_stack_counts = generate_sequences(
nb=nb, nb_steps=nb_steps, nb_stacks=nb_stacks, nb_values=nb_values
)
for n in range(min(10, seq.size(0))):
- print(seq_to_str(seq[n], depth_counts[n]))
+ # print(seq_to_str(seq[n], recorded_stack_counts[n]))
+ print(seq_to_str(seq[n]))
+
+ print("--------------------------------------")
+
+ remove_poped_values(seq, nb_stacks)
+
+ for n in range(min(10, seq.size(0))):
+ print(seq_to_str(seq[n]))