 ######################################################################
-def compute_perplexity(model, split="train"):
+def compute_perplexity(model, fixed_len, split="train"):
     with torch.autograd.no_grad():
         t = model.training
         model.eval()
         for input in task.batches(split=split):
             input = input.to(device)
-            input, order = shuffle(input, task.height * task.width)
-            output = model(mygpt.BracketedSequence(input), order=order).x
+            x, order = shuffle(input, fixed_len)
+            x = model(mygpt.BracketedSequence(x), order=order).x
+            output = reorder(x, order, back=True)
             loss = F.cross_entropy(output.transpose(1, 2), input)
             acc_loss += loss.item() * input.size(0)
             nb_samples += input.size(0)
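
For reference, a minimal sketch of the shuffle/reorder pair that
compute_perplexity relies on. These helpers are not part of this diff,
so everything below is an assumption about their behavior: shuffle
keeps the first fixed_len tokens (the prompt) in canonical order and
permutes the remaining positions independently per sequence, returning
the permutation order alongside; reorder(..., back=True) inverts it so
the logits line up with the unshuffled targets fed to F.cross_entropy.

import torch

def shuffle(x, fixed_len):  # hypothetical sketch, not the repo's code
    # Keep the first fixed_len positions fixed, shuffle the rest
    # independently for each sequence in the batch.
    b, t = x.size()
    noise = torch.rand(b, t, device=x.device)
    noise[:, :fixed_len] = torch.arange(-fixed_len, 0, device=x.device)
    order = noise.sort(dim=1).indices  # order[n, j] = original position of slot j
    return x.gather(1, order), order

def reorder(x, order, back=False):  # hypothetical sketch
    # Apply (or, with back=True, invert) the permutation along dim 1.
    # Works for (B, T) token tensors and (B, T, V) logit tensors alike.
    if back:
        order = order.argsort(dim=1)
    idx = order.view(order.size(0), order.size(1), *([1] * (x.dim() - 2)))
    return x.gather(1, idx.expand_as(x))
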
             ar_mask = result.new_zeros(result.size())
             ar_mask[:, self.height * self.width :] = 1
             result *= 1 - ar_mask
-            result, order = shuffle(result, self.height * self.width)
+            x, order = shuffle(result, self.height * self.width)
             masked_inplace_autoregression(
-                model, self.batch_size, result, ar_mask, order=order
+                model, self.batch_size, x, ar_mask, order=order
             )
-            result = reorder(result, order, back=True)
+            result = reorder(x, order, back=True)
             mazes, paths = self.seq2map(result)
             nb_correct += maze.path_correctness(mazes, paths).long().sum()
             nb_total += mazes.size(0)
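
Note that ar_mask is passed to masked_inplace_autoregression unshuffled
even though x has been permuted. This works because the mask is 1 on
exactly the positions beyond self.height * self.width, and a
prompt-preserving shuffle only permutes those positions among
themselves, so the mask is invariant under the permutation. Under the
shuffle sketched above, this can be checked directly:

# Hypothetical sanity check, assuming the prompt-preserving shuffle:
# gathering the mask with the same order must leave it unchanged.
assert (ar_mask.gather(1, order) == ar_mask).all()
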
 if nb_epochs_finished >= args.nb_epochs:
     n_epoch = nb_epochs_finished
-    train_perplexity = compute_perplexity(model, split="train")
-    test_perplexity = compute_perplexity(model, split="test")
+    train_perplexity = compute_perplexity(
+        model, fixed_len=task.height * task.width, split="train"
+    )
+    test_perplexity = compute_perplexity(
+        model, fixed_len=task.height * task.width, split="test"
+    )
     log_string(
         f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"

     for input in task.batches(split="train"):
         input = input.to(device)
-        input, order = shuffle(input, task.height * task.width)
-        output = model(mygpt.BracketedSequence(input), order=order).x
+        x, order = shuffle(input, task.height * task.width)
+        x = model(mygpt.BracketedSequence(x), order=order).x
+        output = reorder(x, order, back=True)
         loss = F.cross_entropy(output.transpose(1, 2), input)
         acc_train_loss += loss.item() * input.size(0)
         nb_train_samples += input.size(0)
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
     train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
-    test_perplexity = compute_perplexity(model, split="test")
+    test_perplexity = compute_perplexity(
+        model, fixed_len=task.height * task.width, split="test"
+    )
     log_string(
         f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"