-        self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
-
-        test_mazes, test_paths, _ = maze.create_maze_data(
-            nb_test_samples,
-            height=height,
-            width=width,
-            nb_walls=nb_walls,
-            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc="data-test"),
-        )
-        self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
-
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
-
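-    # Yield the chosen split in mini-batches of self.batch_size sequences,
-    # optionally restricted to its first nb_to_use sequences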
-    def batches(self, split="train", nb_to_use=-1, desc=None):
-        assert split in {"train", "test"}
-        input = self.train_input if split == "train" else self.test_input
-        if nb_to_use > 0:
-            input = input[:nb_to_use]
-        if desc is None:
-            desc = f"epoch-{split}"
-        for batch in tqdm.tqdm(
-            input.split(self.batch_size), dynamic_ncols=True, desc=desc
-        ):
-            yield batch
-
-    def vocabulary_size(self):
-        return self.nb_codes
-
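-    # Regenerate the path part of each sequence with the model and check it
-    # against the ground truth. Returns the number of sequences processed,
-    # the number of correct paths, and a matrix whose entry (i, j) counts
-    # the correct paths of optimal length i predicted with length j (or
-    # None if no path was correct)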
-    def compute_error(self, model, split="train", nb_to_use=-1):
-        nb_total, nb_correct = 0, 0
-        count = torch.zeros(
-            self.width * self.height,
-            self.width * self.height,
-            device=self.device,
-            dtype=torch.int64,
-        )
-        for input in tqdm.tqdm(
-            self.batches(split, nb_to_use),
-            dynamic_ncols=True,
-            desc=f"error-{split}",
-        ):
-            # The first height*width tokens encode the maze, the rest the
-            # path; blank out the path part and let the model regenerate it
-            result = input.clone()
-            ar_mask = result.new_zeros(result.size())
-            ar_mask[:, self.height * self.width :] = 1
-            result *= 1 - ar_mask
-            masked_inplace_autoregression(
-                model,
-                self.batch_size,
-                result,
-                ar_mask,
-                progress_bar_desc=None,
-                device=self.device,
-            )
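-            # Decode the sequences back into 2d mazes and paths, and check
-            # which of the predicted paths are valid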
-            mazes, paths = self.seq2map(result)
-            path_correctness = maze.path_correctness(mazes, paths)
-            nb_correct += path_correctness.long().sum()
-            nb_total += mazes.size(0)
-
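-            # A path's length is its number of maze.v_path tokens; compare
-            # predicted vs. optimal lengths for the correct paths only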
-            optimal_path_lengths = (
-                (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
-            )
-            predicted_path_lengths = (
-                (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
-            )
-            optimal_path_lengths = optimal_path_lengths[path_correctness]
-            predicted_path_lengths = predicted_path_lengths[path_correctness]
-            # index_put_ with accumulate=True adds 1 for every pair, whereas
-            # count[i, j] += 1 would count duplicated (i, j) pairs only once
-            count.index_put_(
-                (optimal_path_lengths, predicted_path_lengths),
-                torch.ones_like(optimal_path_lengths),
-                accumulate=True,
-            )
-
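-        # Trim the trailing all-zero rows and columns of the count matrix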
-        if count.max() == 0:
-            count = None
-        else:
-            count = count[
-                : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
-            ]
-
-        return nb_total, nb_correct, count
-
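-    # Log train/test accuracy for the epoch, dump the path-length count
-    # matrix, and render a few test mazes with their predicted paths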
-    def produce_results(self, n_epoch, model):
-        with torch.no_grad():
-            t = model.training
-            model.eval()
-
-            train_nb_total, train_nb_correct, _ = self.compute_error(
-                model, "train", nb_to_use=1000
-            )
-            log_string(
-                f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.2f}%"
-            )
-
-            test_nb_total, test_nb_correct, count = self.compute_error(
-                model, "test", nb_to_use=1000
-            )
-            log_string(
-                f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.2f}%"
-            )
-
-            if count is not None:
-                proportion_optimal = count.diagonal().sum().float() / count.sum()
-                log_string(f"proportion_optimal_test {proportion_optimal*100:.2f}%")
-                with open(
-                    os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
-                ) as f:
-                    for i in range(count.size(0)):
-                        for j in range(count.size(1)):
-                            eol = " " if j < count.size(1) - 1 else "\n"
-                            f.write(f"{count[i,j]}{eol}")
-
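-            # Regenerate the paths of the first 48 test sequences for the
-            # image below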
-            input = self.test_input[:48]
-            result = input.clone()
-            ar_mask = result.new_zeros(result.size())
-            ar_mask[:, self.height * self.width :] = 1
-            result *= 1 - ar_mask
-            masked_inplace_autoregression(
-                model, self.batch_size, result, ar_mask, device=self.device
-            )
-
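-            # Decode both the ground truth and the predictions for rendering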
-            mazes, paths = self.seq2map(input)
-            _, predicted_paths = self.seq2map(result)
-
-            filename = os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.png")
-            maze.save_image(
-                filename,
-                mazes=mazes,
-                target_paths=paths,
-                predicted_paths=predicted_paths,
-                path_correct=maze.path_correctness(mazes, predicted_paths),
-                path_optimal=maze.path_optimality(paths, predicted_paths),
-            )
-            log_string(f"wrote {filename}")
-
-            model.train(t)