- worlds = torch.randint(nb_colors, (nb, height, width), device=device)
- nb_prior_visits = torch.zeros(nb, height, width, device=device)
-
- # nb x 2
- snake_position = torch.cat(
- (
- torch.randint(height, (nb, 1), device=device),
- torch.randint(width, (nb, 1), device=device),
- ),
- 1,
- )
- snake_direction = torch.randint(4, (nb,), device=device)
- sequences = torch.empty(nb, 2 * length, device=device, dtype=torch.int64)
- sequences_prior_visits = torch.zeros(
- nb, 2 * length, device=device, dtype=torch.int64
- )
- i = torch.arange(nb, device=device) # [:,None]
-
- for l in range(length):
- # nb x 3
- snake_next_direction = torch.cat(
- (
- (snake_direction[:, None] - 1) % 4,
- snake_direction[:, None],
- (snake_direction[:, None] + 1) % 4,
- ),
- 1,
- )
-
- # nb x 3
- vh = (snake_next_direction + 1) % 2 * (snake_next_direction - 1)
- vw = snake_next_direction % 2 * (snake_next_direction - 2)
-
- # nb x 3 x 2
- snake_next_speed = torch.cat((vh[:, :, None], vw[:, :, None]), 2)
- snake_next_position = snake_position[:, None, :] + snake_next_speed
-
- # nb x 3
- val = torch.logical_and(
- torch.logical_and(
- snake_next_position[:, :, 0] >= 0, snake_next_position[:, :, 0] < height
- ),
- torch.logical_and(
- snake_next_position[:, :, 1] >= 0, snake_next_position[:, :, 1] < width
- ),
- ).float()
- val = (
- # The multiplicative factors bias toward moving forward
- torch.rand_like(val)
- * val
- * torch.tensor([[1.0, 2.0, 1.0]], device=device)
- )
-
- # nb
- j = val.argmax(1)
- snake_direction = snake_next_direction[i, j]
-
- sequences[:, 2 * l] = worlds[i, snake_position[:, 0], snake_position[:, 1]] + 4
- sequences_prior_visits[:, 2 * l] = nb_prior_visits[
- i, snake_position[:, 0], snake_position[:, 1]
- ]
- if l < prompt_length:
- nb_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1
- sequences[:, 2 * l + 1] = snake_direction
-
- # nb x 2
- snake_position = snake_next_position[i, j]
-
- return sequences, sequences_prior_visits
-
-
-# generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20)
-# exit(0)
-
-
-def snake_solver(input, ar_mask):
- for n in range(input.size(0)):
- i, j, memory = 0, 0, {}
- # print(input[n])
- # print(ar_mask[n])
- for l in range(input.size(1) // 2):
- if ar_mask[n, 2 * l] == 1:
- if memory.get((i, j)) is None:
- input[n, 2 * l] = -1
- else:
- input[n, 2 * l] = memory[(i, j)]
- else:
- # print(f'@3 {memory=}')
- if memory.get((i, j)) is None:
- memory[(i, j)] = input[n, 2 * l]
- else:
- assert memory[(i, j)] == input[n, 2 * l], f"n={n} l={l}"
- # print(f'@1 {i=} {j=}')
- d = input[n, 2 * l + 1].item()
- i += (d + 1) % 2 * (d - 1)
- j += d % 2 * (d - 2)
- # print(f'@2 {i=} {j=}')
-
-
-class TaskSnake(Task):
- def __init__(
- self,
- nb_train_samples,
- nb_test_samples,
- batch_size,
- height,
- width,
- nb_colors,
- length,
- prompt_length,
- device=torch.device("cpu"),
- ):
- self.batch_size = batch_size
- self.height = height
- self.width = width
- self.device = device
- self.prompt_length = prompt_length
-
- self.train_input, self.train_prior_visits = generate_snake_sequences(
- nb_train_samples,
- height,
- width,
- nb_colors,
- length,
- prompt_length,
- self.device,
- )
- self.test_input, self.test_prior_visits = generate_snake_sequences(
- nb_test_samples,
- height,
- width,
- nb_colors,
- length,
- prompt_length,
- self.device,
- )
-
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
-
- def batches(self, split="train", nb_to_use=-1, desc=None):
- assert split in {"train", "test"}
- input = self.train_input if split == "train" else self.test_input
- if nb_to_use > 0:
- input = input[:nb_to_use]
- if desc is None:
- desc = f"epoch-{split}"
- for batch in tqdm.tqdm(
- input.split(self.batch_size), dynamic_ncols=True, desc=desc
- ):
- yield batch
-
- def vocabulary_size(self):
- return self.nb_codes
-
- def produce_results(self, n_epoch, model):
- with torch.autograd.no_grad():
- t = model.training
- model.eval()
-
- def compute_nb_correct(input, prior_visits):
- result = input.clone()
- i = torch.arange(result.size(1), device=result.device)[None, :]
- ar_mask = (
- torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
- .long()
- .expand_as(result)
- )
- result *= 1 - ar_mask