####################

import os

import torch
import tqdm

import problems

# Task, masked_inplace_autoregression, BracketedSequence, and
# save_attention_image are assumed to be defined / imported elsewhere
# in this module.
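
# SandBox wraps a problems.Problem-style sequence generator as a Task:
# it pre-generates train/test sequences together with their
# autoregressive masks, serves mini-batches, and evaluates a model by
# regenerating the masked positions.
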
class SandBox(Task):
    def __init__(
        self,
        problem,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        logger=None,
        device=torch.device("cpu"),
        max_nb_codes=1024,
    ):
        super().__init__()

        self.batch_size = batch_size
        self.device = device
        self.problem = problem

        self.train_input, self.train_ar_mask = self.problem.generate_sequences(
            nb_train_samples
        )
        self.test_input, self.test_ar_mask = self.problem.generate_sequences(
            nb_test_samples
        )

        self.train_input, self.train_ar_mask = self.train_input.to(
            device
        ), self.train_ar_mask.to(device)
        self.test_input, self.test_ar_mask = self.test_input.to(
            device
        ), self.test_ar_mask.to(device)

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

        # A bit of paranoia never hurts
        assert (
            self.nb_codes <= max_nb_codes
            and self.train_input.min() >= 0
            and self.test_input.min() >= 0
            and tuple(self.train_ar_mask.unique()) == (0, 1)
            and tuple(self.test_ar_mask.unique()) == (0, 1)
        )

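    # Yield mini-batches of the chosen split, optionally restricted to
    # the first nb_to_use sequences, with a tqdm progress bar.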
    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

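    # Compute and log the prediction accuracy on the masked positions of
    # (at most nmax) train and test sequences, then optionally save
    # per-head attention maps for a few test sequences.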
    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
    ):
        def compute_accuracy(input, ar_mask, logger=None):
            input, ar_mask = input[:nmax], ar_mask[:nmax]
            result = input.clone() * (1 - ar_mask)

            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,
                ar_mask,
                deterministic_synthesis,
                progress_bar_desc=None,
                device=self.device,
            )

            if logger is not None:
                for sp, st in zip(result[:10], input[:10]):
                    logger(
                        f"test_sequences {n_epoch} prediction   {self.problem.seq2str(sp)}"
                    )
                    logger(
                        f"               {n_epoch} ground truth {self.problem.seq2str(st)}"
                    )

            nb_total = ar_mask.sum().item()
            nb_correct = ((result == input).long() * ar_mask).sum().item()

            return nb_total, nb_correct

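        # Accuracy over the autoregressively generated positions; the
        # test pass also logs a few prediction / ground-truth pairs.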
        train_nb_total, train_nb_correct = compute_accuracy(
            self.train_input, self.train_ar_mask
        )

        logger(
            f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
        )

        test_nb_total, test_nb_correct = compute_accuracy(
            self.test_input, self.test_ar_mask, logger
        )

        logger(
            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
        )

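        # Save the attention maps of every head for a few random test
        # sequences (needs the optional pycairo-based helper).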
        if save_attention_image is None:
            logger("no save_attention_image (is pycairo installed?)")
        else:
            for k in range(10):
                ns = torch.randint(self.test_input.size(0), (1,)).item()
                input = self.test_input[ns : ns + 1].clone()

                with torch.autograd.no_grad():
                    t = model.training
                    model.eval()
                    model.record_attention(True)
                    model(BracketedSequence(input))
                    model.train(t)
                    ram = model.retrieve_attention()
                    model.record_attention(False)

                tokens_output = [c for c in self.problem.seq2str(input[0])]
                tokens_input = ["n/a"] + tokens_output[:-1]
                for n_head in range(ram[0].size(1)):
                    filename = os.path.join(
                        result_dir, f"sandbox_attention_{k}_h{n_head}.pdf"
                    )
                    attention_matrices = [m[0, n_head] for m in ram]
                    save_attention_image(
                        filename,
                        tokens_input,
                        tokens_output,
                        attention_matrices,
                        k_top=10,
                        # min_total_attention=0.9,
                        token_gap=12,
                        layer_gap=50,
                    )
                    logger(f"wrote {filename}")