######################################################################
+# Returns one permutation of the column indices per row of `result`. The
+# first `fixed_len` positions keep their original order: their sort keys
+# are drawn from [-2, -1], below the uniform keys in [0, 1) used for the
+# remaining positions.
+def random_order(result, fixed_len):
+    order = torch.rand(result.size(), device=result.device)
+    order[:, :fixed_len] = torch.linspace(-2, -1, fixed_len, device=order.device)
+    return order.sort(1).indices
+
+
+# Permutes the columns of x according to order; with reorder=True,
+# applies the inverse permutation instead.
+def shuffle(x, order, reorder=False):
+    if x.dim() == 3:
+        order = order.unsqueeze(-1).expand(-1, -1, x.size(-1))
+    if reorder:
+        y = x.new(x.size())
+        y.scatter_(1, order, x)
+        return y
+    else:
+        return x.gather(1, order)
+
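+# A quick sanity check (an illustrative sketch, values assumed, not part
+# of the change): shuffling and then reordering with the same permutation
+# is the identity, and the first fixed_len columns stay in place:
+#
+#   x = torch.arange(12).reshape(2, 6)
+#   order = random_order(x, 2)
+#   assert (order[:, :2] == torch.tensor([0, 1])).all()
+#   assert (shuffle(shuffle(x, order), order, reorder=True) == x).all()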
+
# ar_mask is a Boolean matrix of the same shape as input, with 1s at the
# positions of the tokens to be generated
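+# When order is given, it is only forwarded to the model, which is
+# assumed to consume an already-shuffled sequence; the loop below still
+# fills columns left to right, i.e. in the shuffled coordinates.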
-def masked_inplace_autoregression(model, batch_size, input, ar_mask):
+def masked_inplace_autoregression(model, batch_size, input, ar_mask, order=None):
     for input, ar_mask in zip(input.split(batch_size), ar_mask.split(batch_size)):
         i = (ar_mask.sum(0) > 0).nonzero()
         if i.min() > 0:
             # Needed to initialize the model's cache
-            model(mygpt.BracketedSequence(input, 0, i.min()))
+            model(mygpt.BracketedSequence(input, 0, i.min()), order=order)
         for s in range(i.min(), i.max() + 1):
-            output = model(mygpt.BracketedSequence(input, s, 1)).x
+            output = model(mygpt.BracketedSequence(input, s, 1), order=order).x
             logits = output[:, s]
             if args.deterministic_synthesis:
                 t_next = logits.argmax(1)
     for input in task.batches(split=split):
         input = input.to(device)
-
-        output = model(mygpt.BracketedSequence(input)).x
+        order = random_order(input, task.height * task.width)
+        input = shuffle(input, order)
+        output = model(mygpt.BracketedSequence(input), order=order).x
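+        # the shuffled sequence serves as its own target: each column of
+        # logits is matched with the token at the same shuffled position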
         loss = F.cross_entropy(output.transpose(1, 2), input)
         acc_loss += loss.item() * input.size(0)
         nb_samples += input.size(0)
     acc_train_loss, nb_train_samples = 0, 0
     for mazes, policies in task.policy_batches(split="train"):
-        output_gpt = gpt(mygpt.BracketedSequence(mazes), mode=args.oneshot_input).x
+        order = random_order(mazes, task.height * task.width)
+        x = shuffle(mazes, order)
+        x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x
+        output_gpt = shuffle(x, order, reorder=True)
         output = model(output_gpt)
         loss = compute_loss(mazes, output, policies, task.height, task.width)
     acc_test_loss, nb_test_samples = 0, 0
     for mazes, policies in task.policy_batches(split="test"):
-        output_gpt = gpt(mygpt.BracketedSequence(mazes), mode=args.oneshot_input).x
+        order = random_order(mazes, task.height * task.width)
+        x = shuffle(mazes, order)
+        x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x
+        output_gpt = shuffle(x, order, reorder=True)
         output = model(output_gpt)
         loss = compute_loss(mazes, output, policies, task.height, task.width)
         acc_test_loss += loss.item() * mazes.size(0)
     # -------------------
     mazes = task.test_input[:32, : task.height * task.width]
     policies = task.test_policies[:32]
-    output_gpt = gpt(mygpt.BracketedSequence(mazes), mode=args.oneshot_input).x
+    order = random_order(mazes, task.height * task.width)
+    x = shuffle(mazes, order)
+    x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x
+    output_gpt = shuffle(x, order, reorder=True)
     output = model(output_gpt)
     if args.oneshot_output == "policy":
         targets = policies.permute(0, 2, 1)
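+    # The same shuffle -> gpt -> unshuffle sequence appears three times
+    # above; a hypothetical helper (a sketch, not part of this change)
+    # could factor it out:
+    #
+    #   def gpt_readout(gpt, mazes, seq_len, mode):
+    #       order = random_order(mazes, seq_len)
+    #       x = shuffle(mazes, order)
+    #       x = gpt(mygpt.BracketedSequence(x), mode=mode, order=order).x
+    #       return shuffle(x, order, reorder=True)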
class Task:
-    def batches(self, split="train"):
+    def batches(self, split="train", nb_to_use=-1, desc=None):
         pass
     def vocabulary_size(self):
         self.nb_codes = self.train_input.max() + 1
-    def batches(self, split="train", nb_to_use=-1):
+    def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
         if nb_to_use > 0:
             input = input[:nb_to_use]
+        if desc is None:
+            desc = f"epoch-{split}"
         for batch in tqdm.tqdm(
-            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
+            input.split(self.batch_size), dynamic_ncols=True, desc=desc
         ):
             yield batch
-    def policy_batches(self, split="train", nb_to_use=-1):
+    def policy_batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
         policies = self.train_policies if split == "train" else self.test_policies
         input = input[:nb_to_use]
         policies = policies[:nb_to_use]
+        if desc is None:
+            desc = f"epoch-{split}"
         for batch in tqdm.tqdm(
             zip(input.split(self.batch_size), policies.split(self.batch_size)),
             dynamic_ncols=True,
-            desc=f"epoch-{split}",
+            desc=desc,
         ):
             yield batch
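+    # e.g. task.policy_batches(split="test", desc="oneshot-test")
+    # relabels the tqdm progress bar; with desc=None both methods keep
+    # the previous f"epoch-{split}" label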
         ar_mask = result.new_zeros(result.size())
         ar_mask[:, self.height * self.width :] = 1
         result *= 1 - ar_mask
-        masked_inplace_autoregression(model, self.batch_size, result, ar_mask)
+        order = random_order(result, self.height * self.width)
+        masked_inplace_autoregression(
+            model, self.batch_size, result, ar_mask, order=order
+        )
+        # generation happened in shuffled coordinates, so put the tokens
+        # back in their original positions before decoding the maze
+        result = shuffle(result, order, reorder=True)
         mazes, paths = self.seq2map(result)
         nb_correct += maze.path_correctness(mazes, paths).long().sum()
         nb_total += mazes.size(0)
for input in task.batches(split="train"):
input = input.to(device)
- output = model(mygpt.BracketedSequence(input)).x
+ order = random_order(input, task.height * task.width)
+ input = shuffle(input, order)
+ output = model(mygpt.BracketedSequence(input), order=order).x
loss = F.cross_entropy(output.transpose(1, 2), input)
acc_train_loss += loss.item() * input.size(0)
nb_train_samples += input.size(0)
    # [Vaswani et al 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
-    def forward(self, bs, order=None):
+    def forward(self, bs, order):  # bs.x: N x T x D, order: N x T
         if bs.first == 0:
-            t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
-                :, None
-            ]
-            j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
+            t = (
+                torch.arange(bs.x.size(1) + 1, dtype=bs.x.dtype, device=bs.x.device)[
+                    :, None
+                ]
+                - 1
+            )
+            j = torch.arange(bs.x.size(2) // 2, dtype=bs.x.dtype, device=bs.x.device)[
                 None, :
             ]
             k = j % 2
-            self.pe = torch.sin(
-                t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
+            pe = (
+                torch.sin(
+                    t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
+                )
+                .unsqueeze(0)
+                .expand(bs.x.size(0), -1, -1)
             )
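+            # since sin(a + pi/2) = cos(a), the phase shift (pi/2) * k
+            # with k = j mod 2 turns alternate channels into cosines,
+            # e.g. math.sin(1.0 + math.pi / 2) == math.cos(1.0) up to
+            # float rounding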
-        if order is not None:
-            self.pe = self.pe.gather(1, order.unsqueeze(-1).expand_as(self.pe))
+            # row r of pe encodes position t = r - 1 (row 0 is the
+            # before-the-sequence position); each token receives the
+            # encoding of the position it is read from, concatenated
+            # with that of the position it has to predict
+            order_output = order + 1
+            order_input = torch.cat(
+                (order.new_zeros(order.size(0), 1), order[:, :-1] + 1), 1
+            )
+
+            self.pe = torch.cat(
+                (
+                    pe.gather(1, order_input.unsqueeze(-1).expand(-1, -1, pe.size(-1))),
+                    pe.gather(1, order_output.unsqueeze(-1).expand(-1, -1, pe.size(-1))),
+                ),
+                2,
+            )
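+            # e.g. with T = 3 and order = [[2, 0, 1]] (assumed values):
+            # order_output = [[3, 1, 2]] picks the rows of the positions
+            # to predict, and order_input = [[0, 3, 1]] the rows of the
+            # shifted inputs, row 0 encoding the initial position t = -1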
         self.cache_y = bs.x.new(bs.x.size())
         self.cache_y[:, bs.first : bs.first + bs.nb] = (
-            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
+            bs.slice() + self.pe[:, bs.first : bs.first + bs.nb]
         )
         bs.x = self.cache_y
     def forward(self, bs, mode="standard", order=None):
         bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
-        if order is not None:
-            order = F.pad(order + 1, (1, -1))
+        if order is None:
+            order = torch.arange(bs.x.size(1), device=bs.x.device)[None, :].expand_as(
+                bs.x
+            )
         bs = self.embedding(bs)
         bs = self.pe(bs, order)
                 r += [bs.slice()]
             bs = BracketedSequence(torch.cat(r, -1))
         else:
-            raise ValueError
+            raise ValueError(f"{mode=}")
         return bs
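+# Assumed calling convention (a sketch, not part of the change):
+#
+#   y = model(mygpt.BracketedSequence(input)).x
+#       # order=None -> identity permutation, natural left-to-right order
+#
+#   order = random_order(input, fixed_len)
+#   y = model(mygpt.BracketedSequence(shuffle(input, order)), order=order).x
+#   y = shuffle(y, order, reorder=True)
+#       # processes the sequence in permuted coordinates, then restores it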