logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
- if save_attention_image is not None:
- for k in range(10):
- ns = torch.randint(self.test_input.size(0), (1,)).item()
- input = self.test_input[ns : ns + 1].clone()
+ ##############################
- with torch.autograd.no_grad():
- t = model.training
- model.eval()
- # model.record_attention(True)
- model(BracketedSequence(input))
- model.train(t)
- # ram = model.retrieve_attention()
- # model.record_attention(False)
+ input, ar_mask = self.test_input[:64], self.test_ar_mask[:64]
+ result = input.clone() * (1 - ar_mask)
- # tokens_output = [c for c in self.problem.seq2str(input[0])]
- # tokens_input = ["n/a"] + tokens_output[:-1]
- # for n_head in range(ram[0].size(1)):
- # filename = os.path.join(
- # result_dir, f"sandbox_attention_{k}_h{n_head}.pdf"
- # )
- # attention_matrices = [m[0, n_head] for m in ram]
- # save_attention_image(
- # filename,
- # tokens_input,
- # tokens_output,
- # attention_matrices,
- # k_top=10,
- ##min_total_attention=0.9,
- # token_gap=12,
- # layer_gap=50,
- # )
- # logger(f"wrote {filename}")
+ masked_inplace_autoregression(
+ model,
+ self.batch_size,
+ result,
+ ar_mask,
+ deterministic_synthesis,
+ progress_bar_desc=None,
+ device=self.device,
+ )
+
+ img = world.sample2img(result.to("cpu"), self.height, self.width)
+
+ image_name = os.path.join(result_dir, f"world_result_{n_epoch:04d}.png")
+ torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2)
+ logger(f"wrote {image_name}")
######################################################################