# torch.backends.cuda.matmul.allow_tf32
# torch.autocast(torch.bfloat16)
-import math, sys, argparse, time, tqdm, itertools, os
+import math, sys, argparse, time, tqdm, os
import torch, torchvision
from torch import nn
######################################################################
parser = argparse.ArgumentParser(
- description="An implementation of GPT with cache to solve a toy geometric reasoning task."
+ description="An implementation of GPT with cache.",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--task", type=str, default="picoclvr")
parser.add_argument("--nb_epochs", type=int, default=25)
-parser.add_argument("--batch_size", type=int, default=25)
+parser.add_argument("--batch_size", type=int, default=None)
parser.add_argument("--nb_train_samples", type=int, default=250000)
######################################################################
+default_args = {
+ "picoclvr": {
+ "batch_size": 25,
+ },
+ "mnist": {
+ "batch_size": 10,
+ },
+ "maze": {
+ "batch_size": 25,
+ },
+ "snake": {
+ "batch_size": 20,
+ },
+}
+
+if args.task in default_args:
+ for k, v in default_args[args.task].items():
+ if getattr(args, k) is None:
+ setattr(args, k, v)
+
+######################################################################
+
def log_string(s):
t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
nb, height, width, nb_colors, length, device=torch.device("cpu")
):
worlds = torch.randint(nb_colors, (nb, height, width), device=device)
+ nb_prior_visits = torch.zeros(nb, height, width, device=device)
+
# nb x 2
snake_position = torch.cat(
(
)
snake_direction = torch.randint(4, (nb,), device=device)
sequences = torch.empty(nb, 2 * length, device=device, dtype=torch.int64)
+ sequences_prior_visits = torch.zeros(
+ nb, 2 * length, device=device, dtype=torch.int64
+ )
i = torch.arange(nb, device=device) # [:,None]
for l in range(length):
),
).float()
val = (
- torch.rand_like(val) * val * torch.tensor([[1.0, 4.0, 1.0]], device=device)
+ # The multiplicative factors bias toward moving forward
+ torch.rand_like(val)
+ * val
+ * torch.tensor([[1.0, 2.0, 1.0]], device=device)
)
# nb
snake_direction = snake_next_direction[i, j]
sequences[:, 2 * l] = worlds[i, snake_position[:, 0], snake_position[:, 1]] + 4
+ sequences_prior_visits[:, 2 * l] = nb_prior_visits[
+ i, snake_position[:, 0], snake_position[:, 1]
+ ]
+ nb_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1
sequences[:, 2 * l + 1] = snake_direction
# nb x 2
snake_position = snake_next_position[i, j]
- return sequences, worlds
+ return sequences, sequences_prior_visits
# generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20)
self.width = width
self.device = device
- self.train_input, self.train_worlds = generate_snake_sequences(
+ self.train_input, self.train_prior_visits = generate_snake_sequences(
nb_train_samples, height, width, nb_colors, length, self.device
)
- self.test_input, self.test_worlds = generate_snake_sequences(
+ self.test_input, self.test_prior_visits = generate_snake_sequences(
nb_test_samples, height, width, nb_colors, length, self.device
)
t = model.training
model.eval()
- def compute_nb_correct(input):
+ def compute_nb_correct(input, prior_visits):
+ # Blank out the even-indexed tokens of the second half of each
+ # sequence, let masked_inplace_autoregression regenerate them, and
+ # count how many regenerated tokens match `input`. Only positions
+ # with prior_visits > 0 are counted.
+ # Returns (nb_total, nb_correct), both produced by Tensor.sum().
result = input.clone()
- i = torch.arange(result.size(1), device=result.device)
- ar_mask = torch.logical_and(i >= i.size(0) // 2, i % 2 == 0)[
- None, :
- ].long()
+ # i has shape (1, T), so the sequence length is i.size(1). Using
+ # i.size(0) here would give 1 // 2 == 0 and mask the first half too.
+ i = torch.arange(result.size(1), device=result.device)[None, :]
+ ar_mask = torch.logical_and(i >= i.size(1) // 2, i % 2 == 0).long()
+ # NOTE(review): ar_mask is (1, T) and relies on broadcasting -- confirm
+ # that masked_inplace_autoregression accepts a mask of that shape.
result *= 1 - ar_mask
masked_inplace_autoregression(
model, self.batch_size, result, ar_mask, device=self.device
)
- nb_total = ar_mask.sum() * input.size(0)
- nb_correct = ((result == input).long() * ar_mask).sum()
+ nb_total = (
+ (prior_visits > 0) * ar_mask
+ ).sum()
+
+ nb_correct = (
+ (result == input).long() * (prior_visits > 0) * ar_mask
+ ).sum()
# nb_total = result.size(0)
# nb_correct = ((result - input).abs().sum(1) == 0).sum()
return nb_total, nb_correct
- train_nb_total, train_nb_correct = compute_nb_correct(self.train_input)
+ train_nb_total, train_nb_correct = compute_nb_correct(
+ self.train_input, self.train_prior_visits
+ )
log_string(
f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
)
- test_nb_total, test_nb_correct = compute_nb_correct(self.test_input)
+ test_nb_total, test_nb_correct = compute_nb_correct(
+ self.test_input, self.test_prior_visits
+ )
log_string(
f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"