def generation_order(x, fixed_len):
if args.random_regression_order:
order = torch.rand(x.size(), device=x.device)
- order[:, :fixed_len] = torch.linspace(-2, -1, fixed_len, device=order.device)
+ order[:, :fixed_len] = torch.linspace(-2, -1, fixed_len, device=x.device)
order = order.sort(1).indices
else:
        # Default: the plain left-to-right order, identical for every sequence.
        order = (
            torch.arange(x.size(1), device=x.device).unsqueeze(0).expand(x.size(0), -1)
        )
    return order
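Why the linspace trick works: torch.rand draws sort keys in [0, 1), while the linspace keys lie in [-2, -1], strictly below them. After sort(1), the first fixed_len positions therefore always come out first, and in left-to-right order (the linspace is increasing), while the remaining positions form a uniform random permutation. A standalone demo, not part of the patch:

import torch

# Keys below 0 pin the conditioning prefix to the front of the permutation;
# the increasing linspace keeps the prefix in its original order.
keys = torch.rand(1, 6)
keys[:, :3] = torch.linspace(-2, -1, 3)
order = keys.sort(1).indices
assert order[0, :3].tolist() == [0, 1, 2]          # prefix first, in order
assert sorted(order[0, 3:].tolist()) == [3, 4, 5]  # suffix randomly permuted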
def masked_inplace_autoregression(model, batch_size, input, ar_mask, order=None):
- for input, ar_mask in zip(input.split(batch_size), ar_mask.split(batch_size)):
+ for input, ar_mask, order in zip(
+ input.split(batch_size), ar_mask.split(batch_size), order.split(batch_size)
+ ):
i = (ar_mask.sum(0) > 0).nonzero()
if i.min() > 0:
# Needed to initialize the model's cache
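Note that the rewritten loop calls order.split(batch_size) unconditionally, so the order=None default in the signature is now a trap: calling the function without an order raises AttributeError instead of falling back to left-to-right generation. A hypothetical guard, not in the patch, that would restore the old default if placed before the split loop:

# Hypothetical fallback (not in the patch): synthesize the identity
# left-to-right order when the caller provides none, so that
# order.split(batch_size) never sees None.
if order is None:
    order = (
        torch.arange(input.size(1), device=input.device)
        .unsqueeze(0)
        .expand(input.size(0), -1)
    )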
######################################################################
-def compute_perplexity(model, fixed_len, split="train"):
+def compute_perplexity(model, task, fixed_len, split="train"):
with torch.autograd.no_grad():
t = model.training
model.eval()
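Threading task through compute_perplexity presumably lets the function pull its own batches from the task rather than relying on a global. The context lines also show the usual save-and-restore of the training flag around evaluation; the same bracket packaged as a reusable context manager, for illustration only:

import contextlib
import torch

@contextlib.contextmanager
def evaluating(model):
    # Same effect as "t = model.training; model.eval(); ...; model.train(t)"
    # above, with gradients disabled for the duration of the block.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            yield model
    finally:
        model.train(was_training)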
ar_mask = result.new_zeros(result.size())
ar_mask[:, self.height * self.width :] = 1
result *= 1 - ar_mask
- masked_inplace_autoregression(model, self.batch_size, result, ar_mask)
+ x, order = shuffle(result, self.height * self.width)
+ masked_inplace_autoregression(
+ model, self.batch_size, x, ar_mask, order=order
+ )
+ result = reorder(x, order, back=True)
mazes, paths = self.seq2map(input)
_, predicted_paths = self.seq2map(result)
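shuffle and reorder are used here but defined elsewhere in the file. A minimal sketch consistent with their use in this hunk, assuming order is the per-row permutation produced by generation_order (first hunk) and back=True applies its inverse:

import torch

def reorder(x, order, back=False):
    # x: N x T, order: N x T permutation of column indices.
    # Forward: v[n, t] = x[n, order[n, t]]; with back=True the columns are
    # scattered back to their original positions (the inverse permutation).
    if back:
        return x.new_zeros(x.size()).scatter_(1, order, x)
    return x.gather(1, order)

def shuffle(x, fixed_len):
    # Sample a generation order (prefix pinned, remainder random) and return
    # the sequence reordered accordingly, together with the order itself.
    order = generation_order(x, fixed_len)
    return reorder(x, order), order

The round trip then reads: generate in the shuffled coordinate system, then map the result back so that seq2map sees tokens in their original raster order.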
if nb_epochs_finished >= args.nb_epochs:
n_epoch = nb_epochs_finished
train_perplexity = compute_perplexity(
- model, fixed_len=task.height * task.width, split="train"
+ model, task, fixed_len=task.height * task.width, split="train"
)
test_perplexity = compute_perplexity(
- model, fixed_len=task.height * task.width, split="test"
+ model, task, fixed_len=task.height * task.width, split="test"
)
log_string(
train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
test_perplexity = compute_perplexity(
- model, fixed_len=task.height * task.width, split="test"
+ model, task, fixed_len=task.height * task.width, split="test"
)
log_string(
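The min(100, ...) clamp on the train_perplexity line bounds the exponent so math.exp cannot overflow when the model is untrained or diverges. A hypothetical helper factoring out that expression:

import math

def perplexity(acc_loss, nb_samples):
    # exp of the mean per-sample loss; capping the exponent at 100
    # (exp(100) is about 2.7e43) avoids OverflowError while still reading
    # as effectively infinite perplexity in the logs.
    return math.exp(min(100, acc_loss / nb_samples))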